summaryrefslogtreecommitdiff
path: root/modules/remote-sycl
diff options
context:
space:
mode:
authorClaudius "keldu" Holeksa <mail@keldu.de>2024-06-14 14:33:22 +0200
committerClaudius "keldu" Holeksa <mail@keldu.de>2024-06-14 14:33:22 +0200
commit5329652f839b99b95d63cd471ff73d251f74d911 (patch)
tree20797eb65f3e48686979362f828a6e07b21e9b5a /modules/remote-sycl
parent57f6eacfcdbdba31185eb66b9a573a8923eecf16 (diff)
Fixed calc of sycl vals
Diffstat (limited to 'modules/remote-sycl')
-rw-r--r--modules/remote-sycl/c++/remote.hpp47
-rw-r--r--modules/remote-sycl/examples/SConscript1
-rw-r--r--modules/remote-sycl/examples/sycl_basic.cpp41
-rw-r--r--modules/remote-sycl/examples/sycl_basic_kernel.cpp3
4 files changed, 79 insertions, 13 deletions
diff --git a/modules/remote-sycl/c++/remote.hpp b/modules/remote-sycl/c++/remote.hpp
index bcc8a3c..4510237 100644
--- a/modules/remote-sycl/c++/remote.hpp
+++ b/modules/remote-sycl/c++/remote.hpp
@@ -22,22 +22,34 @@ class remote_data<T, Encoding, Storage, rmt::Sycl> {
private:
id<T> id_;
id_map<T,Encoding,rmt::Sycl>* map_;
- cl::sycl::queue* queue_;
public:
/**
* Main constructor
*/
remote_data(const id<T>& id, id_map<T, Encoding, rmt::Sycl>& map, cl::sycl::queue& queue__):
id_{id},
- map_{&map},
- queue_{&queue__}
+ map_{&map}
{}
/**
* Wait for the data
*/
error_or<data<T,Encoding,Storage>> wait(){
+ auto eov = map_->find(id_);
+ if(eov.is_error()){
+ auto& err = eov.get_error();
+ return std::move(err);
+ }
+ auto& val = eov.get_value();
+ std::cout<<"Values Sycl in Map: "<<val->size()<<std::endl;
+ {
+ auto eocop = val->template copy_to_host<storage::Default>();
+ if(eocop.is_error()){
+ return eocop;
+ }
+ return eocop.get_value();
+ }
}
/**
@@ -87,7 +99,9 @@ public:
{
if(!device_data_){
total_length_ = 0u;
+ return;
}
+ queue_->wait();
}
template<typename Encoding, typename Storage>
@@ -102,6 +116,7 @@ public:
return;
}
queue_->template copy<data<T,encode::Native,storage::Default>>(&from.at(0), device_data_, total_length_);
+ queue_->wait();
}
data(const data<Schema, encode::Native, rmt::Sycl>& from):
@@ -116,7 +131,10 @@ public:
// device_data_ = cl::sycl::malloc_device<typename native_data_type<T>::type>(from.size(), *queue_);
if(!device_data_){
total_length_ = 0u;
+ return;
}
+
+ queue_->template copy<data<T,encode::Native,storage::Default>>(from.device_data_, device_data_, total_length_);
}
data<Schema, encode::Native, rmt::Sycl>& operator=(const data<Schema, encode::Native, rmt::Sycl>& rhs) {
@@ -158,7 +176,20 @@ public:
}
}
- data<T, encode::Native, saw::storage::Default>& at(uint64_t i){
+ /**
+ * Allocate appropriate meta data and then copy to host
+ */
+ template<typename Storage>
+ error_or<data<Schema, encode::Native, Storage>> copy_to_host() const {
+ data<Schema,encode::Native, Storage> data_{total_length_};
+
+ /// TODO Check success
+ queue_->template copy<data<T,encode::Native,storage::Default>>(device_data_, &data_.at(0), total_length_);
+ queue_->wait();
+ return data_;
+ }
+
+ data<T, encode::Native, storage::Default>& at(uint64_t i){
//typename native_data_type<T>::type& at(uint64_t i){
return device_data_[i];
}
@@ -247,6 +278,9 @@ public:
return std::get<id_map<T,Encoding,rmt::Sycl>>(storage_.maps).next_free_id();
}
+ /**
+ * Main constructor
+ */
rpc_server(interface<Iface, Encoding, rmt::Sycl, InterfaceCtxT> cl_iface):
cmd_queue_{},
cl_interface_{std::move(cl_iface)},
@@ -255,12 +289,11 @@ public:
template<typename IdT, typename Storage>
remote_data<IdT, Encoding, Storage, rmt::Sycl> request_data(id<IdT> dat_id){
- /// @TODO Fix so I can receive data
- return {dat_id, std::get<id_map<IdT, Encoding,rmt::Sycl>>(storage_.maps)};
+ return {dat_id, std::get<id_map<IdT,Encoding,rmt::Sycl>>(storage_.maps), cmd_queue_};
}
/**
- * Rpc call
+ * Rpc call based on the name
*/
template<string_literal Name, typename ClientAllocation>
error_or<
diff --git a/modules/remote-sycl/examples/SConscript b/modules/remote-sycl/examples/SConscript
index 02e528b..015b492 100644
--- a/modules/remote-sycl/examples/SConscript
+++ b/modules/remote-sycl/examples/SConscript
@@ -15,6 +15,7 @@ examples_env = env.Clone();
examples_sycl_env = examples_env.Clone();
examples_sycl_env['CXX'] = 'acpp';
+examples_sycl_env['CXXFLAGS'] += ['-O2'];
examples_env.sources = sorted(glob.glob(dir_path + "/*.cpp"))
examples_env.headers = sorted(glob.glob(dir_path + "/*.hpp"))
diff --git a/modules/remote-sycl/examples/sycl_basic.cpp b/modules/remote-sycl/examples/sycl_basic.cpp
index 2e9a4f8..486aca1 100644
--- a/modules/remote-sycl/examples/sycl_basic.cpp
+++ b/modules/remote-sycl/examples/sycl_basic.cpp
@@ -22,8 +22,12 @@ int main(){
saw::rpc_client<schema::BasicInterface, saw::encode::Native, saw::storage::Default, saw::rmt::Sycl> client{rpc_server};
saw::id<schema::Array<schema::UInt64>> id_zero{0u};
+ saw::data<schema::Array<schema::UInt64>, saw::encode::Native> ex_data{1u};
+ ex_data.at(0u).set(50u);
{
- auto eov = client.template call<"increment">(saw::data<schema::Array<schema::UInt64>, saw::encode::Native>{1u});
+ auto eov = client.template call<"increment">(
+ ex_data
+ );
if(eov.is_error()){
auto& err = eov.get_error();
std::cerr<<"Error: "<<err.get_category()<<" : "<<err.get_message()<<std::endl;
@@ -32,17 +36,46 @@ int main(){
id_zero = eov.get_value();
}
{
+ auto rmt_data = rpc_server.request_data<schema::Array<schema::UInt64>, saw::storage::Default>(id_zero);
+ auto eo_rd = rmt_data.wait();
+ if(eo_rd.is_error()){
+ auto& err = eo_rd.get_error();
+ std::cerr<<"Error: "<<err.get_category()<<" : "<<err.get_message()<<std::endl;
+ return -2;
+ }
+
+ auto& val = eo_rd.get_value();
+ std::cout<<"Values: "<<val.size()<<"\n";
+ for(uint64_t i = 0; i < val.size(); ++i){
+ std::cout<<val.at(i).get()<<'\t';
+ }
+ std::cout<<std::endl;
+ }
+ saw::id<schema::Array<schema::UInt64>> id_one{1u};
+ {
auto eov = client.template call<"increment">(id_zero);
if(eov.is_error()){
auto& err = eov.get_error();
std::cerr<<"Error: "<<err.get_category()<<" : "<<err.get_message()<<std::endl;
return -2;
}
- auto& val = eov.get_value();
- std::cout<<"Value: "<<val.get_value()<<std::endl;
+ id_one = eov.get_value();
}
{
- // auto eo_rd = rpc_server.request_data(id_one);
+ auto rmt_data = rpc_server.request_data<schema::Array<schema::UInt64>, saw::storage::Default>(id_one);
+ auto eo_rd = rmt_data.wait();
+ if(eo_rd.is_error()){
+ auto& err = eo_rd.get_error();
+ std::cerr<<"Error: "<<err.get_category()<<" : "<<err.get_message()<<std::endl;
+ return -2;
+ }
+
+ auto& val = eo_rd.get_value();
+ std::cout<<"Values: "<<val.size()<<"\n";
+ for(uint64_t i = 0; i < val.size(); ++i){
+ std::cout<<val.at(i).get()<<'\t';
+ }
+ std::cout<<std::endl;
}
return 0;
diff --git a/modules/remote-sycl/examples/sycl_basic_kernel.cpp b/modules/remote-sycl/examples/sycl_basic_kernel.cpp
index 888f905..03f0bac 100644
--- a/modules/remote-sycl/examples/sycl_basic_kernel.cpp
+++ b/modules/remote-sycl/examples/sycl_basic_kernel.cpp
@@ -4,9 +4,8 @@ saw::rpc_server<schema::BasicInterface, saw::encode::Native, saw::rmt::Sycl> lis
saw::interface<schema::BasicInterface, saw::encode::Native, saw::rmt::Sycl, cl::sycl::queue*> iface{
[](saw::data<saw::schema::Array<saw::schema::UInt64>, saw::encode::Native, saw::rmt::Sycl> in, cl::sycl::queue* q) -> saw::data<saw::schema::Array<saw::schema::UInt64>, saw::encode::Native, saw::rmt::Sycl> {
-
q->submit([&](cl::sycl::handler& h){
- h.parallel_for(cl::sycl::range<1>(1u), [&] (cl::sycl::id<1> it){
+ h.single_task([&] (){
in.at(0u).set(in.at(0u).get() + 1u);
});
});