Skip to content
Snippets Groups Projects
Commit 55db4453 authored by Behzad Safaei's avatar Behzad Safaei
Browse files

Add bit-field implementation for property flags

parent e9ad898a
No related branches found
No related tags found
No related merge requests found
...@@ -17,6 +17,7 @@ class PairsAcessor: ...@@ -17,6 +17,7 @@ class PairsAcessor:
if self.target.is_gpu(): if self.target.is_gpu():
self.print("namespace pairs::internal{") self.print("namespace pairs::internal{")
self.print.add_indent(4) self.print.add_indent(4)
self.PropFlags_struct()
self.DeviceProps_struct() self.DeviceProps_struct()
self.HostProps_struct() self.HostProps_struct()
self.print.add_indent(-4) self.print.add_indent(-4)
...@@ -61,6 +62,17 @@ class PairsAcessor: ...@@ -61,6 +62,17 @@ class PairsAcessor:
self.print("};") self.print("};")
self.print("") self.print("")
def PropFlags_struct(self):
self.print("struct PropFlags{")
self.print.add_indent(4)
for p in self.sim.properties:
pname = p.name()
self.print(f"unsigned int {pname} : 1;")
self.print.add_indent(-4)
self.print("};")
self.print("")
def DeviceProps_struct(self): def DeviceProps_struct(self):
self.print("struct DeviceProps{") self.print("struct DeviceProps{")
self.print.add_indent(4) self.print.add_indent(4)
...@@ -74,13 +86,10 @@ class PairsAcessor: ...@@ -74,13 +86,10 @@ class PairsAcessor:
pname = p.name() pname = p.name()
tkw = Types.c_keyword(self.sim, p.type()) tkw = Types.c_keyword(self.sim, p.type())
self.print(f"{tkw} *{pname}_d;") self.print(f"{tkw} *{pname}_d;")
self.print("") self.print("")
self.print("//Property device flag pointers")
for p in self.sim.properties: self.print("//Device pointer to device flags")
pname = p.name() self.print(f"PropFlags *prop_device_flags_d;")
tkw = Types.c_keyword(self.sim, Types.Boolean)
self.print(f"{tkw} *{pname}_device_flag_d;")
self.print("") self.print("")
self.print("//Feature properties on device are global") self.print("//Feature properties on device are global")
...@@ -103,17 +112,11 @@ class PairsAcessor: ...@@ -103,17 +112,11 @@ class PairsAcessor:
self.print("") self.print("")
self.print("//Property host flags") self.print("//Property host flags")
for p in self.sim.properties: self.print("PropFlags prop_host_flags = {};")
pname = p.name()
tkw = Types.c_keyword(self.sim, Types.Boolean)
self.print(f"{tkw} {pname}_host_flag = false;")
self.print("") self.print("")
self.print("//Property device flags") self.print("//Property device flags")
for p in self.sim.properties: self.print("PropFlags prop_device_flags_h = {};")
pname = p.name()
tkw = Types.c_keyword(self.sim, Types.Boolean)
self.print(f"{tkw} {pname}_device_flag_h = false;")
self.print("") self.print("")
self.print("//Feature property host pointers are in PairsObjects") self.print("//Feature property host pointers are in PairsObjects")
...@@ -163,12 +166,9 @@ class PairsAcessor: ...@@ -163,12 +166,9 @@ class PairsAcessor:
self.print(f"hp = new pairs::internal::HostProps;") self.print(f"hp = new pairs::internal::HostProps;")
self.print(f"dp_h = new pairs::internal::DeviceProps;") self.print(f"dp_h = new pairs::internal::DeviceProps;")
self.print(f"cudaMalloc(&dp_d, sizeof(pairs::internal::DeviceProps));") self.print(f"cudaMalloc(&dp_d, sizeof(pairs::internal::DeviceProps));")
self.print(f"cudaMalloc(&(dp_h->prop_device_flags_d), sizeof(pairs::internal::PropFlags));")
self.print(f"cudaMemset(dp_h->prop_device_flags_d, 0, sizeof(pairs::internal::PropFlags));")
for p in self.sim.properties:
pname = p.name()
tkw = Types.c_keyword(self.sim, Types.Boolean)
self.print(f"cudaMalloc(&(dp_h->{pname}_device_flag_d), sizeof({tkw}));")
self.print("this->update();") self.print("this->update();")
self.print.add_indent(-4) self.print.add_indent(-4)
self.print("}") self.print("}")
...@@ -182,12 +182,7 @@ class PairsAcessor: ...@@ -182,12 +182,7 @@ class PairsAcessor:
self.print(f"{self.host_attr} void end(){{") self.print(f"{self.host_attr} void end(){{")
if self.target.is_gpu(): if self.target.is_gpu():
self.print.add_indent(4) self.print.add_indent(4)
self.print(f"cudaFree(dp_h->prop_device_flags_d);")
for p in self.sim.properties:
pname = p.name()
tkw = Types.c_keyword(self.sim, Types.Boolean)
self.print(f"cudaFree(dp_h->{pname}_device_flag_d);")
self.print(f"cudaFree(dp_d);") self.print(f"cudaFree(dp_d);")
self.print(f"delete dp_h;") self.print(f"delete dp_h;")
self.print(f"delete hp;") self.print(f"delete hp;")
...@@ -264,12 +259,16 @@ class PairsAcessor: ...@@ -264,12 +259,16 @@ class PairsAcessor:
def setter_body(self, prop, device=False): def setter_body(self, prop, device=False):
self.print.add_indent(4) self.print.add_indent(4)
ptr = self.generate_ref_name(prop, device) ptr = self.generate_ref_name(prop, device)
pname = prop.name()
if isinstance(prop, Property): if isinstance(prop, Property):
idx = "i" idx = "i"
flag = f"dp_d->prop_device_flags_d->{pname}" if device else f"hp->prop_host_flags.{pname}"
elif isinstance(prop, FeatureProperty): elif isinstance(prop, FeatureProperty):
fname = prop.feature().name() fname = prop.feature().name()
idx = f"({prop.feature().nkinds()}*{fname}1 + {fname}2)" idx = f"({prop.feature().nkinds()}*{fname}1 + {fname}2)"
flag = f"hp->{pname}_host_flag"
if Types.is_scalar(prop.type()): if Types.is_scalar(prop.type()):
self.print(f"{ptr}[{idx}] = value;") self.print(f"{ptr}[{idx}] = value;")
...@@ -279,8 +278,6 @@ class PairsAcessor: ...@@ -279,8 +278,6 @@ class PairsAcessor:
self.print(f"{ptr}[{idx}*{nelems} + {n}] = value[{n}];") self.print(f"{ptr}[{idx}*{nelems} + {n}] = value[{n}];")
if self.target.is_gpu(): if self.target.is_gpu():
pname = prop.name()
flag = f"*(dp_d->{pname}_device_flag_d)" if device else f"hp->{pname}_host_flag"
self.print(f"{flag} = true;") self.print(f"{flag} = true;")
self.print.add_indent(-4) self.print.add_indent(-4)
...@@ -316,7 +313,7 @@ class PairsAcessor: ...@@ -316,7 +313,7 @@ class PairsAcessor:
def sync_ctx_enum(self): def sync_ctx_enum(self):
self.print("enum SyncContext{") self.print("enum SyncContext{")
self.print(" Host = 0,") self.print(" Host = 0,")
self.print(" Device") self.print(" Device = 1")
self.print("};") self.print("};")
self.print("") self.print("")
...@@ -330,24 +327,24 @@ class PairsAcessor: ...@@ -330,24 +327,24 @@ class PairsAcessor:
if self.target.is_gpu(): if self.target.is_gpu():
self.print.add_indent(4) self.print.add_indent(4)
self.print(f"cudaMemcpy(&(hp->{pname}_device_flag_h), dp_h->{pname}_device_flag_d, sizeof(bool), cudaMemcpyDeviceToHost);") self.print(f"cudaMemcpy(&(hp->prop_device_flags_h), dp_h->prop_device_flags_d, sizeof(pairs::internal::PropFlags), cudaMemcpyDeviceToHost);")
self.print("") self.print("")
##################################################################################################################### #####################################################################################################################
##################################################################################################################### #####################################################################################################################
self.print(f"if (hp->{pname}_host_flag && hp->{pname}_device_flag_h){{") self.print(f"if (hp->prop_host_flags.{pname} && hp->prop_device_flags_h.{pname}){{")
self.print(f" PAIRS_ERROR(\"OUT OF SYNC 1! Both host and device versions of {pname} are in a modified state.\\n\");") self.print(f" PAIRS_ERROR(\"OUT OF SYNC 1! Both host and device versions of {pname} are in a modified state.\\n\");")
self.print(" exit(-1);") self.print(" exit(-1);")
self.print("}") self.print("}")
self.print(f"else if(sync_ctx==Host && overwrite==false){{") self.print(f"else if(sync_ctx==Host && overwrite==false){{")
self.print(f" if (hp->{pname}_host_flag && !ps->pairs_runtime->getPropFlags()->isHostFlagSet({pid})){{") self.print(f" if (hp->prop_host_flags.{pname} && !ps->pairs_runtime->getPropFlags()->isHostFlagSet({pid})){{")
self.print(f" PAIRS_ERROR(\"OUT OF SYNC 2! Did you forget to sync{funcname}(Host) before calling set{funcname} from host? Use sync{funcname}(Host,true) if you want to overwrite {pname} values in host.\\n\");") self.print(f" PAIRS_ERROR(\"OUT OF SYNC 2! Did you forget to sync{funcname}(Host) before calling set{funcname} from host? Use sync{funcname}(Host,true) if you want to overwrite {pname} values in host.\\n\");")
self.print(" exit(-1);") self.print(" exit(-1);")
self.print(" }") self.print(" }")
self.print("}") self.print("}")
self.print(f"else if(sync_ctx==Device && overwrite==false){{") self.print(f"else if(sync_ctx==Device && overwrite==false){{")
self.print(f" if (hp->{pname}_device_flag_h && !ps->pairs_runtime->getPropFlags()->isDeviceFlagSet({pid})){{") self.print(f" if (hp->prop_device_flags_h.{pname} && !ps->pairs_runtime->getPropFlags()->isDeviceFlagSet({pid})){{")
self.print(f" PAIRS_ERROR(\"OUT OF SYNC 3! Did you forget to sync{funcname}(Device) before calling set{funcname} from device? Use sync{funcname}(Device,true) if you want to overwrite {pname} values in device.\\n\");") self.print(f" PAIRS_ERROR(\"OUT OF SYNC 3! Did you forget to sync{funcname}(Device) before calling set{funcname} from device? Use sync{funcname}(Device,true) if you want to overwrite {pname} values in device.\\n\");")
self.print(" exit(-1);") self.print(" exit(-1);")
self.print(" }") self.print(" }")
...@@ -357,12 +354,12 @@ class PairsAcessor: ...@@ -357,12 +354,12 @@ class PairsAcessor:
##################################################################################################################### #####################################################################################################################
##################################################################################################################### #####################################################################################################################
self.print(f"if (hp->{pname}_host_flag){{") self.print(f"if (hp->prop_host_flags.{pname}){{")
self.print(f" ps->pairs_runtime->getPropFlags()->setHostFlag({pid});") self.print(f" ps->pairs_runtime->getPropFlags()->setHostFlag({pid});")
self.print(f" ps->pairs_runtime->getPropFlags()->clearDeviceFlag({pid});") self.print(f" ps->pairs_runtime->getPropFlags()->clearDeviceFlag({pid});")
self.print("}") self.print("}")
self.print(f"else if (hp->{pname}_device_flag_h){{") self.print(f"else if (hp->prop_device_flags_h.{pname}){{")
self.print(f" ps->pairs_runtime->getPropFlags()->setDeviceFlag({pid});") self.print(f" ps->pairs_runtime->getPropFlags()->setDeviceFlag({pid});")
self.print(f" ps->pairs_runtime->getPropFlags()->clearHostFlag({pid});") self.print(f" ps->pairs_runtime->getPropFlags()->clearHostFlag({pid});")
self.print("}") self.print("}")
...@@ -381,9 +378,9 @@ class PairsAcessor: ...@@ -381,9 +378,9 @@ class PairsAcessor:
self.print("}") self.print("}")
self.print("") self.print("")
self.print(f"hp->{pname}_host_flag = false;") self.print(f"hp->prop_host_flags.{pname} = false;")
self.print(f"hp->{pname}_device_flag_h = false;") self.print(f"hp->prop_device_flags_h.{pname} = false;")
self.print(f"cudaMemcpy(dp_h->{pname}_device_flag_d, &(hp->{pname}_device_flag_h), sizeof(bool), cudaMemcpyHostToDevice);") self.print(f"cudaMemcpy(dp_h->prop_device_flags_d, &(hp->prop_device_flags_h), sizeof(pairs::internal::PropFlags), cudaMemcpyHostToDevice);")
self.print.add_indent(-4) self.print.add_indent(-4)
self.print("}") self.print("}")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment