Skip to content

Commit

Permalink
Pass policy into param functions by value
Browse files Browse the repository at this point in the history
This makes this use normal overload resolution instead of
trying to take everything through the RAJA::detail namespace.
Theoretically this should allow a user to make their own param.
  • Loading branch information
MrBurmark committed Feb 18, 2025
1 parent 30a8e7b commit 3a4a65b
Show file tree
Hide file tree
Showing 25 changed files with 177 additions and 135 deletions.
43 changes: 23 additions & 20 deletions include/RAJA/pattern/params/forall.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,45 +43,45 @@ struct ForallParamPack
private:
// Init
template<typename EXEC_POL, camp::idx_t... Seq, typename... Args>
static constexpr void params_init(EXEC_POL,
static constexpr void params_init(EXEC_POL const& pol,
camp::idx_seq<Seq...>,
ForallParamPack& f_params,
Args&&... args)
{
CAMP_EXPAND(param_init<EXEC_POL>(camp::get<Seq>(f_params.param_tup),
std::forward<Args>(args)...));
CAMP_EXPAND(param_init(pol, camp::get<Seq>(f_params.param_tup),
std::forward<Args>(args)...));
}

// Combine
template<typename EXEC_POL, camp::idx_t... Seq>
RAJA_HOST_DEVICE static constexpr void params_combine(
EXEC_POL,
EXEC_POL const& pol,
camp::idx_seq<Seq...>,
ForallParamPack& out,
const ForallParamPack& in)
{
CAMP_EXPAND(param_combine<EXEC_POL>(camp::get<Seq>(out.param_tup),
camp::get<Seq>(in.param_tup)));
CAMP_EXPAND(param_combine(pol, camp::get<Seq>(out.param_tup),
camp::get<Seq>(in.param_tup)));
}

template<typename EXEC_POL, camp::idx_t... Seq>
RAJA_HOST_DEVICE static constexpr void params_combine(
EXEC_POL,
EXEC_POL const& pol,
camp::idx_seq<Seq...>,
ForallParamPack& f_params)
{
CAMP_EXPAND(param_combine<EXEC_POL>(camp::get<Seq>(f_params.param_tup)));
CAMP_EXPAND(param_combine(pol, camp::get<Seq>(f_params.param_tup)));
}

// Resolve
template<typename EXEC_POL, camp::idx_t... Seq, typename... Args>
static constexpr void params_resolve(EXEC_POL,
static constexpr void params_resolve(EXEC_POL const& pol,
camp::idx_seq<Seq...>,
ForallParamPack& f_params,
Args&&... args)
{
CAMP_EXPAND(param_resolve<EXEC_POL>(camp::get<Seq>(f_params.param_tup),
std::forward<Args>(args)...));
CAMP_EXPAND(param_resolve(pol, camp::get<Seq>(f_params.param_tup),
std::forward<Args>(args)...));
}

// Used to construct the argument TYPES that will be invoked with the lambda.
Expand Down Expand Up @@ -155,32 +155,35 @@ struct ParamMultiplexer
typename... Params,
typename... Args,
typename FP = ForallParamPack<Params...>>
static void constexpr params_init(ForallParamPack<Params...>& f_params,
Args&&... args)
static void constexpr params_init(EXEC_POL const& pol,
ForallParamPack<Params...>& f_params,
Args&&... args)
{
FP::params_init(EXEC_POL(), typename FP::params_seq(), f_params,
FP::params_init(pol, typename FP::params_seq(), f_params,
std::forward<Args>(args)...);
}

template<typename EXEC_POL,
typename... Params,
typename... Args,
typename FP = ForallParamPack<Params...>>
static void constexpr params_combine(ForallParamPack<Params...>& f_params,
Args&&... args)
static void constexpr params_combine(EXEC_POL const& pol,
ForallParamPack<Params...>& f_params,
Args&&... args)
{
FP::params_combine(EXEC_POL(), typename FP::params_seq(), f_params,
FP::params_combine(pol, typename FP::params_seq(), f_params,
std::forward<Args>(args)...);
}

template<typename EXEC_POL,
typename... Params,
typename... Args,
typename FP = ForallParamPack<Params...>>
static void constexpr params_resolve(ForallParamPack<Params...>& f_params,
Args&&... args)
static void constexpr params_resolve(EXEC_POL const& pol,
ForallParamPack<Params...>& f_params,
Args&&... args)
{
FP::params_resolve(EXEC_POL(), typename FP::params_seq(), f_params,
FP::params_resolve(pol, typename FP::params_seq(), f_params,
std::forward<Args>(args)...);
}
};
Expand Down
14 changes: 7 additions & 7 deletions include/RAJA/policy/cuda/forall.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ __launch_bounds__(BlockSize, BlocksPerSM) __global__
{
RAJA::expt::invoke_body(f_params, body, idx[ii]);
}
RAJA::expt::ParamMultiplexer::params_combine<EXEC_POL>(f_params);
RAJA::expt::ParamMultiplexer::params_combine(EXEC_POL{}, f_params);
}

///
Expand Down Expand Up @@ -474,7 +474,7 @@ __global__ void forallp_cuda_kernel(LOOP_BODY loop_body,
{
RAJA::expt::invoke_body(f_params, body, idx[ii]);
}
RAJA::expt::ParamMultiplexer::params_combine<EXEC_POL>(f_params);
RAJA::expt::ParamMultiplexer::params_combine(EXEC_POL{}, f_params);
}

template<
Expand Down Expand Up @@ -565,7 +565,7 @@ __launch_bounds__(BlockSize, BlocksPerSM) __global__
{
RAJA::expt::invoke_body(f_params, body, idx[ii]);
}
RAJA::expt::ParamMultiplexer::params_combine<EXEC_POL>(f_params);
RAJA::expt::ParamMultiplexer::params_combine(EXEC_POL{}, f_params);
}

///
Expand Down Expand Up @@ -597,7 +597,7 @@ __global__ void forallp_cuda_kernel(LOOP_BODY loop_body,
{
RAJA::expt::invoke_body(f_params, body, idx[ii]);
}
RAJA::expt::ParamMultiplexer::params_combine<EXEC_POL>(f_params);
RAJA::expt::ParamMultiplexer::params_combine(EXEC_POL{}, f_params);
}

} // namespace impl
Expand Down Expand Up @@ -712,7 +712,7 @@ forall_impl(resources::Cuda cuda_res,
IterationGetter,
Concretizer,
BlocksPerSM,
Async> const&,
Async> const& pol,
Iterable&& iter,
LoopBody&& loop_body,
ForallParam f_params)
Expand Down Expand Up @@ -764,7 +764,7 @@ forall_impl(resources::Cuda cuda_res,
launch_info.res = cuda_res;

{
RAJA::expt::ParamMultiplexer::params_init<EXEC_POL>(f_params, launch_info);
RAJA::expt::ParamMultiplexer::params_init(pol, f_params, launch_info);

//
// Privatize the loop_body, using make_launch_body to setup reductions
Expand All @@ -781,7 +781,7 @@ forall_impl(resources::Cuda cuda_res,
RAJA::cuda::launch(func, dims.blocks, dims.threads, args, shmem, cuda_res,
Async);

RAJA::expt::ParamMultiplexer::params_resolve<EXEC_POL>(f_params, launch_info);
RAJA::expt::ParamMultiplexer::params_resolve(pol, f_params, launch_info);
}

RAJA_FT_END;
Expand Down
18 changes: 9 additions & 9 deletions include/RAJA/policy/cuda/launch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ __global__ void launch_new_reduce_global_fcn(BODY body_in,
RAJA::expt::invoke_body(reduce_params, body, ctx);

// Using a flatten global policy as we may use all dimensions
RAJA::expt::ParamMultiplexer::params_combine<RAJA::cuda_flatten_global_xyz_direct>(
RAJA::expt::ParamMultiplexer::params_combine(RAJA::cuda_flatten_global_xyz_direct{},
reduce_params);
}

Expand Down Expand Up @@ -186,7 +186,7 @@ struct LaunchExecute<
{
using EXEC_POL = RAJA::policy::cuda::cuda_launch_explicit_t<
async, named_usage::unspecified, named_usage::unspecified>;
RAJA::expt::ParamMultiplexer::params_init<EXEC_POL>(launch_reducers,
RAJA::expt::ParamMultiplexer::params_init(EXEC_POL{}, launch_reducers,
launch_info);


Expand All @@ -204,7 +204,7 @@ struct LaunchExecute<
RAJA::cuda::launch(func, gridSize, blockSize, args, shared_mem_size,
cuda_res, async, kernel_name);

RAJA::expt::ParamMultiplexer::params_resolve<EXEC_POL>(launch_reducers,
RAJA::expt::ParamMultiplexer::params_resolve(EXEC_POL{}, launch_reducers,
launch_info);
}

Expand Down Expand Up @@ -253,7 +253,7 @@ __launch_bounds__(num_threads, BLOCKS_PER_SM) __global__
RAJA::expt::invoke_body(reduce_params, body, ctx);

// Using a flatten global policy as we may use all dimensions
RAJA::expt::ParamMultiplexer::params_combine<RAJA::cuda_flatten_global_xyz_direct>(
RAJA::expt::ParamMultiplexer::params_combine(RAJA::cuda_flatten_global_xyz_direct{},
reduce_params);
}

Expand Down Expand Up @@ -378,11 +378,11 @@ struct LaunchExecute<
launch_info.dynamic_smem = &shared_mem_size;
launch_info.res = cuda_res;
{

// Use a generic block size policy here to match that used in params_combine
using EXEC_POL =
RAJA::policy::cuda::cuda_launch_explicit_t<async, nthreads,
BLOCKS_PER_SM>;
RAJA::expt::ParamMultiplexer::params_init<EXEC_POL>(launch_reducers,
RAJA::policy::cuda::cuda_launch_explicit_t<
async, named_usage::unspecified, named_usage::unspecified>;
RAJA::expt::ParamMultiplexer::params_init(EXEC_POL{}, launch_reducers,
launch_info);

//
Expand All @@ -399,7 +399,7 @@ struct LaunchExecute<
RAJA::cuda::launch(func, gridSize, blockSize, args, shared_mem_size,
cuda_res, async, kernel_name);

RAJA::expt::ParamMultiplexer::params_resolve<EXEC_POL>(launch_reducers,
RAJA::expt::ParamMultiplexer::params_resolve(EXEC_POL{}, launch_reducers,
launch_info);
}

Expand Down
5 changes: 4 additions & 1 deletion include/RAJA/policy/cuda/params/reduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ namespace detail
// Init
template<typename EXEC_POL, typename OP, typename T, typename VOp>
camp::concepts::enable_if<type_traits::is_cuda_policy<EXEC_POL>> param_init(
EXEC_POL const&,
Reducer<OP, T, VOp>& red,
RAJA::cuda::detail::cudaInfo& ci)
{
Expand All @@ -34,7 +35,8 @@ camp::concepts::enable_if<type_traits::is_cuda_policy<EXEC_POL>> param_init(
template<typename EXEC_POL, typename OP, typename T, typename VOp>
RAJA_HOST_DEVICE camp::concepts::enable_if<
type_traits::is_cuda_policy<EXEC_POL>>
param_combine(Reducer<OP, T, VOp>& red)
param_combine(EXEC_POL const&,
Reducer<OP, T, VOp>& red)
{
RAJA::cuda::impl::expt::grid_reduce<typename EXEC_POL::IterationGetter, OP>(
red.devicetarget, red.getVal(), red.device_mem, red.device_count);
Expand All @@ -43,6 +45,7 @@ param_combine(Reducer<OP, T, VOp>& red)
// Resolve
template<typename EXEC_POL, typename OP, typename T, typename VOp>
camp::concepts::enable_if<type_traits::is_cuda_policy<EXEC_POL>> param_resolve(
EXEC_POL const&,
Reducer<OP, T, VOp>& red,
RAJA::cuda::detail::cudaInfo& ci)
{
Expand Down
14 changes: 7 additions & 7 deletions include/RAJA/policy/hip/forall.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ __launch_bounds__(BlockSize, 1) __global__
{
RAJA::expt::invoke_body(f_params, body, idx[ii]);
}
RAJA::expt::ParamMultiplexer::params_combine<EXEC_POL>(f_params);
RAJA::expt::ParamMultiplexer::params_combine(EXEC_POL{}, f_params);
}

///
Expand Down Expand Up @@ -471,7 +471,7 @@ __global__ void forallp_hip_kernel(LOOP_BODY loop_body,
{
RAJA::expt::invoke_body(f_params, body, idx[ii]);
}
RAJA::expt::ParamMultiplexer::params_combine<EXEC_POL>(f_params);
RAJA::expt::ParamMultiplexer::params_combine(EXEC_POL{}, f_params);
}

template<
Expand Down Expand Up @@ -559,7 +559,7 @@ __launch_bounds__(BlockSize, 1) __global__
{
RAJA::expt::invoke_body(f_params, body, idx[ii]);
}
RAJA::expt::ParamMultiplexer::params_combine<EXEC_POL>(f_params);
RAJA::expt::ParamMultiplexer::params_combine(EXEC_POL{}, f_params);
}

///
Expand Down Expand Up @@ -590,7 +590,7 @@ __global__ void forallp_hip_kernel(LOOP_BODY loop_body,
{
RAJA::expt::invoke_body(f_params, body, idx[ii]);
}
RAJA::expt::ParamMultiplexer::params_combine<EXEC_POL>(f_params);
RAJA::expt::ParamMultiplexer::params_combine(EXEC_POL{}, f_params);
}

} // namespace impl
Expand Down Expand Up @@ -699,7 +699,7 @@ RAJA_INLINE concepts::enable_if_t<
forall_impl(
resources::Hip hip_res,
::RAJA::policy::hip::
hip_exec<IterationMapping, IterationGetter, Concretizer, Async> const&,
hip_exec<IterationMapping, IterationGetter, Concretizer, Async> const& pol,
Iterable&& iter,
LoopBody&& loop_body,
ForallParam f_params)
Expand Down Expand Up @@ -751,7 +751,7 @@ forall_impl(
launch_info.res = hip_res;

{
RAJA::expt::ParamMultiplexer::params_init<EXEC_POL>(f_params, launch_info);
RAJA::expt::ParamMultiplexer::params_init(pol, f_params, launch_info);

//
// Privatize the loop_body, using make_launch_body to setup reductions
Expand All @@ -768,7 +768,7 @@ forall_impl(
RAJA::hip::launch(func, dims.blocks, dims.threads, args, shmem, hip_res,
Async);

RAJA::expt::ParamMultiplexer::params_resolve<EXEC_POL>(f_params, launch_info);
RAJA::expt::ParamMultiplexer::params_resolve(pol, f_params, launch_info);
}

RAJA_FT_END;
Expand Down
12 changes: 6 additions & 6 deletions include/RAJA/policy/hip/launch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ __global__ void launch_new_reduce_global_fcn(BODY body_in,
RAJA::expt::invoke_body(reduce_params, body, ctx);

// Using a flatten global policy as we may use all dimensions
RAJA::expt::ParamMultiplexer::params_combine<RAJA::hip_flatten_global_xyz_direct>(
RAJA::expt::ParamMultiplexer::params_combine(RAJA::hip_flatten_global_xyz_direct{},
reduce_params);
}

Expand Down Expand Up @@ -184,7 +184,7 @@ struct LaunchExecute<
{
using EXEC_POL =
RAJA::policy::hip::hip_launch_t<async, named_usage::unspecified>;
RAJA::expt::ParamMultiplexer::params_init<EXEC_POL>(launch_reducers,
RAJA::expt::ParamMultiplexer::params_init(EXEC_POL{}, launch_reducers,
launch_info);

//
Expand All @@ -201,7 +201,7 @@ struct LaunchExecute<
RAJA::hip::launch(func, gridSize, blockSize, args, shared_mem_size,
hip_res, async, kernel_name);

RAJA::expt::ParamMultiplexer::params_resolve<EXEC_POL>(launch_reducers,
RAJA::expt::ParamMultiplexer::params_resolve(EXEC_POL{}, launch_reducers,
launch_info);
}

Expand Down Expand Up @@ -247,7 +247,7 @@ __launch_bounds__(num_threads, 1) __global__
RAJA::expt::invoke_body(reduce_params, body, ctx);

// Using a flatten global policy as we may use all dimensions
RAJA::expt::ParamMultiplexer::params_combine<RAJA::hip_flatten_global_xyz_direct>(
RAJA::expt::ParamMultiplexer::params_combine(RAJA::hip_flatten_global_xyz_direct{},
reduce_params);
}

Expand Down Expand Up @@ -370,7 +370,7 @@ struct LaunchExecute<RAJA::policy::hip::hip_launch_t<async, nthreads>>
{
using EXEC_POL =
RAJA::policy::hip::hip_launch_t<async, named_usage::unspecified>;
RAJA::expt::ParamMultiplexer::params_init<EXEC_POL>(launch_reducers,
RAJA::expt::ParamMultiplexer::params_init(EXEC_POL{}, launch_reducers,
launch_info);

//
Expand All @@ -387,7 +387,7 @@ struct LaunchExecute<RAJA::policy::hip::hip_launch_t<async, nthreads>>
RAJA::hip::launch(func, gridSize, blockSize, args, shared_mem_size,
hip_res, async, kernel_name);

RAJA::expt::ParamMultiplexer::params_resolve<EXEC_POL>(launch_reducers,
RAJA::expt::ParamMultiplexer::params_resolve(EXEC_POL{}, launch_reducers,
launch_info);
}

Expand Down
5 changes: 4 additions & 1 deletion include/RAJA/policy/hip/params/kernel_name.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ namespace detail
// Init
template<typename EXEC_POL>
camp::concepts::enable_if<type_traits::is_hip_policy<EXEC_POL>> param_init(
EXEC_POL const&,
KernelName& kn,
const RAJA::hip::detail::hipInfo&)
{
Expand All @@ -34,12 +35,14 @@ camp::concepts::enable_if<type_traits::is_hip_policy<EXEC_POL>> param_init(
// Combine
template<typename EXEC_POL>
RAJA_HOST_DEVICE camp::concepts::enable_if<type_traits::is_hip_policy<EXEC_POL>>
param_combine(KernelName&)
param_combine(EXEC_POL const&,
KernelName&)
{}

// Resolve
template<typename EXEC_POL>
camp::concepts::enable_if<type_traits::is_hip_policy<EXEC_POL>> param_resolve(
EXEC_POL const&,
KernelName&,
const RAJA::hip::detail::hipInfo&)
{
Expand Down
Loading

0 comments on commit 3a4a65b

Please sign in to comment.