diff --git a/cpp/daal/BUILD b/cpp/daal/BUILD index 0cbf8a50316..1b4cbe4e9ec 100644 --- a/cpp/daal/BUILD +++ b/cpp/daal/BUILD @@ -223,6 +223,7 @@ daal_algorithms( "stump", "svd", "svm", + "spectral_embedding", "weak_learner/inner", ], ) diff --git a/cpp/daal/src/algorithms/spectral_embedding/BUILD b/cpp/daal/src/algorithms/spectral_embedding/BUILD new file mode 100644 index 00000000000..5fdba550542 --- /dev/null +++ b/cpp/daal/src/algorithms/spectral_embedding/BUILD @@ -0,0 +1,11 @@ +package(default_visibility = ["//visibility:public"]) +load("@onedal//dev/bazel:daal.bzl", "daal_module") + +daal_module( + name = "kernel", + auto = True, + deps = [ + "@onedal//cpp/daal:core", + "@onedal//cpp/daal/src/algorithms/cosdistance:kernel", + ], +) diff --git a/cpp/daal/src/algorithms/spectral_embedding/spectral_embedding_default_dense_fpt_cpu.cpp b/cpp/daal/src/algorithms/spectral_embedding/spectral_embedding_default_dense_fpt_cpu.cpp new file mode 100644 index 00000000000..684de1f8657 --- /dev/null +++ b/cpp/daal/src/algorithms/spectral_embedding/spectral_embedding_default_dense_fpt_cpu.cpp @@ -0,0 +1,39 @@ +/* file: spectral_embedding_default_dense_fpt_cpu.cpp */ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* +//++ +// Instantiation of CPU-specific spectral_embedding kernel implementations +//-- +*/ + +#include "spectral_embedding_kernel.h" +#include "spectral_embedding_default_dense_impl.i" + +namespace daal +{ +namespace algorithms +{ +namespace spectral_embedding +{ +namespace internal +{ +template class DAAL_EXPORT SpectralEmbeddingKernel; +} // namespace internal +} // namespace spectral_embedding +} // namespace algorithms +} // namespace daal diff --git a/cpp/daal/src/algorithms/spectral_embedding/spectral_embedding_default_dense_impl.i b/cpp/daal/src/algorithms/spectral_embedding/spectral_embedding_default_dense_impl.i new file mode 100644 index 00000000000..5f3d7981c6a --- /dev/null +++ b/cpp/daal/src/algorithms/spectral_embedding/spectral_embedding_default_dense_impl.i @@ -0,0 +1,214 @@ +/* file: spectral_embedding_default_dense_impl.i */ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* +//++ +// Implementation of cosine distance. +//-- +*/ + +#include "services/daal_defines.h" +#include "src/externals/service_math.h" +#include "src/externals/service_blas.h" +#include "src/threading/threading.h" +#include "src/algorithms/service_error_handling.h" +#include "src/data_management/service_numeric_table.h" +#include "src/algorithms/cosdistance/cosdistance_kernel.h" +#include "src/externals/service_lapack.h" +#include + +using namespace daal::internal; + +namespace daal +{ +namespace algorithms +{ +namespace spectral_embedding +{ +namespace internal +{ + +template +services::Status computeEigenvectorsInplace(size_t nFeatures, algorithmFPType * eigenvectors, algorithmFPType * eigenvalues) +{ + char jobz = 'V'; + char uplo = 'U'; + + DAAL_INT lwork = 2 * nFeatures * nFeatures + 6 * nFeatures + 1; + DAAL_INT liwork = 5 * nFeatures + 3; + DAAL_INT info; + + TArray work(lwork); + TArray iwork(liwork); + DAAL_CHECK_MALLOC(work.get() && iwork.get()); + + LapackInst::xsyevd(&jobz, &uplo, (DAAL_INT *)(&nFeatures), eigenvectors, (DAAL_INT *)(&nFeatures), eigenvalues, work.get(), + &lwork, iwork.get(), &liwork, &info); + if (info != 0) return services::Status(services::ErrorPCAFailedToComputeCorrelationEigenvalues); // CHANGE ERROR STATUS + return services::Status(); +} + +/** + * \brief Kernel for Spectral Embedding calculation + */ +template +services::Status SpectralEmbeddingKernel::compute(const NumericTable * xTable, NumericTable * embeddingTable, + NumericTable * eigenTable, const KernelParameter & par) +{ + services::Status status; + // std::cout << "inside DAAL kernel" << std::endl; + // std::cout << "Params: " << par.numberOfEmbeddings << " " << par.numberOfNeighbors << std::endl; + size_t k = par.numberOfEmbeddings; + size_t filtNum = par.numberOfNeighbors + 1; + size_t n = xTable->getNumberOfRows(); /* Number of input feature vectors */ + + SharedPtr > tmpMatrixPtr = + HomogenNumericTable::create(n, n, NumericTable::doAllocate, &status); + + DAAL_CHECK_STATUS_VAR(status); + NumericTable * covOutput = tmpMatrixPtr.get(); + NumericTable * a0 = const_cast(xTable); + NumericTable * eigenvalues = const_cast(eigenTable); + + // Compute cosine distances matrix + { + auto cosDistanceKernel = cosine_distance::internal::DistanceKernel(); + DAAL_CHECK_STATUS(status, cosDistanceKernel.compute(0, &a0, 0, &covOutput, nullptr)); + } + + WriteRows xMatrix(covOutput, 0, n); + DAAL_CHECK_BLOCK_STATUS(xMatrix); + algorithmFPType * x = xMatrix.get(); + + size_t lcnt, rcnt, cnt; + algorithmFPType L, R, M; + // Use binary search to find such d that the number of verticies having distance <= d is filtNum + const size_t binarySearchIterNum = 20; + // TODO: add parallel_for + for (size_t i = 0; i < n; ++i) + { + L = 0; // min possible cos distance + R = 2; // max possible cos distance + lcnt = 0; // number of elements with cos distance <= L + rcnt = n; // number of elements with cos distance <= R + for (size_t ij = 0; ij < binarySearchIterNum; ++ij) + { + M = (L + R) / 2; + cnt = 0; + // Calculate the number of elements in the row with value <= M + for (size_t j = 0; j < n; ++j) + { + if (x[i * n + j] <= M) + { + cnt++; + } + } + if (cnt < filtNum) + { + L = M; + lcnt = cnt; + } + else + { + R = M; + rcnt = cnt; + } + // distance threshold is found + if (rcnt == filtNum) + { + break; + } + } + // create edges for the closest neighbors + for (size_t j = 0; j < n; ++j) + { + if (x[i * n + j] <= R) + { + x[i * n + j] = 1.0; + } + else + { + x[i * n + j] = 0.0; + } + } + // fill the diagonal of matrix with zeros + x[i * n + i] = 0; + } + + // Create Laplassian matrix + for (size_t i = 0; i < n; ++i) + { + for (size_t j = 0; j < i; ++j) + { + algorithmFPType val = (x[i * n + j] + x[j * n + i]) / 2; + x[i * n + j] = -val; + x[j * n + i] = -val; + x[i * n + i] += val; + x[j * n + j] += val; + } + } + + // std::cout << "Laplacian matrix" << std::endl; + // for (int i = 0; i < n; ++i) { + // for (int j = 0; j < n; ++j) { + // std::cout << x[i * n + j] << " "; + // } + // std::cout << std::endl; + // } + // std::cout << "------" << std::endl; + + // Find the eigen vectors and eigne values of the matix + //TArray eigenvalues(n); + //DAAL_CHECK_MALLOC(eigenvalues.get()); + WriteRows eigenValuesBlock(eigenvalues, 0, n); + DAAL_CHECK_BLOCK_STATUS(eigenValuesBlock); + algorithmFPType * eigenValuesPtr = eigenValuesBlock.get(); + + status |= computeEigenvectorsInplace(n, x, eigenValuesPtr); + DAAL_CHECK_STATUS_VAR(status); + + // std::cout << "Eigen vectors: " << std::endl; + // for (int i = 0; i < n; ++i) { + // for (int j = 0; j < n; ++j) { + // std::cout << x[i * n + j] << " "; + // } + // std::cout << std::endl; + // } + + // Fill the output matrix with eigen vectors corresponding to the smallest eigen values + WriteOnlyRows embedMatrix(embeddingTable, 0, n); + DAAL_CHECK_BLOCK_STATUS(embedMatrix); + algorithmFPType * embed = embedMatrix.get(); + + for (int i = 0; i < k; ++i) + { + for (int j = 0; j < n; ++j) + { + embed[j * k + i] = x[i * n + j]; + } + } + + return status; +} + +} // namespace internal + +} // namespace spectral_embedding + +} // namespace algorithms + +} // namespace daal diff --git a/cpp/daal/src/algorithms/spectral_embedding/spectral_embedding_kernel.h b/cpp/daal/src/algorithms/spectral_embedding/spectral_embedding_kernel.h new file mode 100644 index 00000000000..0248ac4b288 --- /dev/null +++ b/cpp/daal/src/algorithms/spectral_embedding/spectral_embedding_kernel.h @@ -0,0 +1,66 @@ +/* file: spectral_embedding_kernel.h */ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +/* +//++ +// Declaration of template structs that calculate SVM Training functions. +//-- +*/ + +#ifndef __SPECTRAL_EMBEDDING_KERNEL_H__ +#define __SPECTRAL_EMBEDDING_KERNEL_H__ + +#include "data_management/data/numeric_table.h" +#include "services/daal_defines.h" +#include "src/algorithms/kernel.h" + +namespace daal +{ +namespace algorithms +{ +namespace spectral_embedding +{ + +enum Method +{ + defaultDense = 0 +}; + +namespace internal +{ + +using namespace daal::data_management; +using namespace daal::services; + +struct KernelParameter : daal::algorithms::Parameter +{ + size_t numberOfEmbeddings = 1; + size_t numberOfNeighbors = 1; +}; + +template +struct SpectralEmbeddingKernel : public Kernel +{ + services::Status compute(const NumericTable * xTable, NumericTable * embeddingTable, NumericTable * eigenTable, const KernelParameter & par); +}; + +} // namespace internal +} // namespace spectral_embedding +} // namespace algorithms +} // namespace daal + +#endif diff --git a/cpp/oneapi/dal/algo/BUILD b/cpp/oneapi/dal/algo/BUILD index e93804d2e7e..a7981639dfc 100644 --- a/cpp/oneapi/dal/algo/BUILD +++ b/cpp/oneapi/dal/algo/BUILD @@ -34,6 +34,7 @@ ALGOS = [ "rbf_kernel", "sigmoid_kernel", "shortest_paths", + "spectral_embedding", "subgraph_isomorphism", "svm", "triangle_counting", diff --git a/cpp/oneapi/dal/algo/spectral_embedding.hpp b/cpp/oneapi/dal/algo/spectral_embedding.hpp new file mode 100644 index 00000000000..3606f826809 --- /dev/null +++ b/cpp/oneapi/dal/algo/spectral_embedding.hpp @@ -0,0 +1,19 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/algo/spectral_embedding/compute.hpp" diff --git a/cpp/oneapi/dal/algo/spectral_embedding/BUILD b/cpp/oneapi/dal/algo/spectral_embedding/BUILD new file mode 100644 index 00000000000..88640f87665 --- /dev/null +++ b/cpp/oneapi/dal/algo/spectral_embedding/BUILD @@ -0,0 +1,38 @@ +package(default_visibility = ["//visibility:public"]) +load("@onedal//dev/bazel:dal.bzl", + "dal_module", + "dal_test_suite", +) + +dal_module( + name = "spectral_embedding", + auto = True, + dal_deps = [ + "@onedal//cpp/oneapi/dal:core", + "@onedal//cpp/oneapi/dal/backend/primitives:common", + ], + extra_deps = [ + "@onedal//cpp/daal/src/algorithms/spectral_embedding:kernel", + ] +) + +dal_test_suite( + name = "interface_tests", + framework = "catch2", + hdrs = glob([ + "test/*.hpp", + ]), + srcs = glob([ + "test/*.cpp", + ]), + dal_deps = [ + ":spectral_embedding", + ], +) + +dal_test_suite( + name = "tests", + tests = [ + ":interface_tests", + ], +) diff --git a/cpp/oneapi/dal/algo/spectral_embedding/backend/cpu/compute_kernel.cpp b/cpp/oneapi/dal/algo/spectral_embedding/backend/cpu/compute_kernel.cpp new file mode 100644 index 00000000000..04dd7e92609 --- /dev/null +++ b/cpp/oneapi/dal/algo/spectral_embedding/backend/cpu/compute_kernel.cpp @@ -0,0 +1,110 @@ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "daal/src/algorithms/spectral_embedding/spectral_embedding_kernel.h" + +#include "oneapi/dal/algo/spectral_embedding/backend/cpu/compute_kernel.hpp" +#include "oneapi/dal/backend/interop/common.hpp" +#include "oneapi/dal/backend/interop/error_converter.hpp" +#include "oneapi/dal/backend/interop/table_conversion.hpp" + +#include "oneapi/dal/table/row_accessor.hpp" +#include "oneapi/dal/detail/error_messages.hpp" +#include + +namespace oneapi::dal::spectral_embedding::backend { + +using dal::backend::context_cpu; +using descriptor_t = detail::descriptor_base; + +namespace sp_emb = daal::algorithms::spectral_embedding; + +template +using daal_sp_emb_kernel_t = + sp_emb::internal::SpectralEmbeddingKernel; + +using parameter_t = sp_emb::internal::KernelParameter; + +namespace interop = oneapi::dal::backend::interop; + +template +static compute_result call_daal_kernel(const context_cpu& ctx, + const descriptor_t& desc, + const table& data) { + const auto daal_data = interop::convert_to_daal_table(data); + + // const std::int64_t p = data.get_column_count(); + const std::int64_t n = data.get_row_count(); + std::int64_t k = desc.get_component_count(); + + // std::cout << "inside oneDAL kernel: " << n << " " << p << std::endl; + + auto result = compute_result{}.set_result_options(desc.get_result_options()); + + if (result.get_result_options().test(result_options::embedding) || + result.get_result_options().test(result_options::eigen_values)) { + daal::services::SharedPtr daal_input, daal_output, daal_eigen_vals; + array arr_output, arr_eigen_vals; + arr_output = array::empty(n * k); + arr_eigen_vals = array::empty(n); + daal_output = interop::convert_to_daal_homogen_table(arr_output, n, k); + daal_eigen_vals = interop::convert_to_daal_homogen_table(arr_eigen_vals, n, 1); + parameter_t daal_param; + + daal_param.numberOfEmbeddings = k; + if (desc.get_neighbor_count() < 0) { + daal_param.numberOfNeighbors = n - 1; + } + else { + daal_param.numberOfNeighbors = desc.get_neighbor_count(); + } + interop::status_to_exception( + interop::call_daal_kernel(ctx, + daal_data.get(), + daal_output.get(), + daal_eigen_vals.get(), + daal_param)); + if (result.get_result_options().test(result_options::embedding)) { + result.set_embedding(homogen_table::wrap(arr_output, n, k)); + } + if (result.get_result_options().test(result_options::eigen_values)) { + result.set_eigen_values(homogen_table::wrap(arr_eigen_vals, n, 1)); + } + } + + return result; +} + +template +static compute_result compute(const context_cpu& ctx, + const descriptor_t& desc, + const compute_input& input) { + return call_daal_kernel(ctx, desc, input.get_data()); +} + +template +struct compute_kernel_cpu { + compute_result operator()(const context_cpu& ctx, + const descriptor_t& desc, + const compute_input& input) const { + return compute(ctx, desc, input); + } +}; + +template struct compute_kernel_cpu; +template struct compute_kernel_cpu; + +} // namespace oneapi::dal::spectral_embedding::backend diff --git a/cpp/oneapi/dal/algo/spectral_embedding/backend/cpu/compute_kernel.hpp b/cpp/oneapi/dal/algo/spectral_embedding/backend/cpu/compute_kernel.hpp new file mode 100644 index 00000000000..855903d883c --- /dev/null +++ b/cpp/oneapi/dal/algo/spectral_embedding/backend/cpu/compute_kernel.hpp @@ -0,0 +1,31 @@ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/algo/spectral_embedding/compute_types.hpp" +#include "oneapi/dal/backend/dispatcher.hpp" + +namespace oneapi::dal::spectral_embedding::backend { + +template +struct compute_kernel_cpu { + compute_result operator()(const dal::backend::context_cpu& ctx, + const detail::descriptor_base& params, + const compute_input& input) const; +}; + +} // namespace oneapi::dal::spectral_embedding::backend diff --git a/cpp/oneapi/dal/algo/spectral_embedding/backend/gpu/compute_kernel.hpp b/cpp/oneapi/dal/algo/spectral_embedding/backend/gpu/compute_kernel.hpp new file mode 100644 index 00000000000..b7f4fdc63d8 --- /dev/null +++ b/cpp/oneapi/dal/algo/spectral_embedding/backend/gpu/compute_kernel.hpp @@ -0,0 +1,31 @@ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/algo/spectral_embedding/compute_types.hpp" +#include "oneapi/dal/backend/dispatcher.hpp" + +namespace oneapi::dal::spectral_embedding::backend { + +template +struct compute_kernel_gpu { + compute_result operator()(const dal::backend::context_gpu& ctx, + const detail::descriptor_base& params, + const compute_input& input) const; +}; + +} // namespace oneapi::dal::spectral_embedding::backend diff --git a/cpp/oneapi/dal/algo/spectral_embedding/backend/gpu/compute_kernel_dpc.cpp b/cpp/oneapi/dal/algo/spectral_embedding/backend/gpu/compute_kernel_dpc.cpp new file mode 100644 index 00000000000..7cfe7dbf8b5 --- /dev/null +++ b/cpp/oneapi/dal/algo/spectral_embedding/backend/gpu/compute_kernel_dpc.cpp @@ -0,0 +1,36 @@ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "oneapi/dal/algo/spectral_embedding/backend/gpu/compute_kernel.hpp" +#include "oneapi/dal/exceptions.hpp" + +namespace oneapi::dal::spectral_embedding::backend { + +template +struct compute_kernel_gpu { + compute_result operator()(const dal::backend::context_gpu& ctx, + const detail::descriptor_base& desc, + const compute_input& input) const { + // CHANGE ERROR MESSAGE + throw unimplemented( + dal::detail::error_messages::sp_emb_dense_batch_method_is_not_implemented_for_gpu()); + } +}; + +template struct compute_kernel_gpu; +template struct compute_kernel_gpu; + +} // namespace oneapi::dal::spectral_embedding::backend diff --git a/cpp/oneapi/dal/algo/spectral_embedding/common.cpp b/cpp/oneapi/dal/algo/spectral_embedding/common.cpp new file mode 100644 index 00000000000..64f279f56c5 --- /dev/null +++ b/cpp/oneapi/dal/algo/spectral_embedding/common.cpp @@ -0,0 +1,94 @@ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "oneapi/dal/algo/spectral_embedding/common.hpp" +#include "oneapi/dal/exceptions.hpp" + +namespace oneapi::dal::spectral_embedding::detail { + +result_option_id get_embedding_id() { + return result_option_id{ result_option_id::make_by_index(0) }; +} + +result_option_id get_eigen_values_id() { + return result_option_id{ result_option_id::make_by_index(1) }; +} + +template +result_option_id get_default_result_options() { + return result_option_id{}; +} + +template <> +result_option_id get_default_result_options() { + return get_embedding_id(); +} + +namespace v1 { + +template +class descriptor_impl : public base { +public: + explicit descriptor_impl() {} + + std::int64_t component_count = 0; + std::int64_t neighbor_count = -1; + + result_option_id result_options = get_default_result_options(); +}; + +template +descriptor_base::descriptor_base() : impl_(new descriptor_impl{}) {} + +template +std::int64_t descriptor_base::get_component_count() const { + return impl_->component_count; +} + +template +std::int64_t descriptor_base::get_neighbor_count() const { + return impl_->neighbor_count; +} + +template +void descriptor_base::set_component_count_impl(std::int64_t component_count) { + impl_->component_count = component_count; +} + +template +void descriptor_base::set_neighbor_count_impl(std::int64_t neighbor_count) { + impl_->neighbor_count = neighbor_count; +} + +template +result_option_id descriptor_base::get_result_options() const { + return impl_->result_options; +} + +template +void descriptor_base::set_result_options_impl(const result_option_id& value) { + using msg = dal::detail::error_messages; + if (!bool(value)) { + throw domain_error(msg::empty_set_of_result_options()); + } + impl_->result_options = value; +} + +template class ONEDAL_EXPORT descriptor_base; + +} // namespace v1 + +} // namespace oneapi::dal::spectral_embedding::detail diff --git a/cpp/oneapi/dal/algo/spectral_embedding/common.hpp b/cpp/oneapi/dal/algo/spectral_embedding/common.hpp new file mode 100644 index 00000000000..8ceb5fda692 --- /dev/null +++ b/cpp/oneapi/dal/algo/spectral_embedding/common.hpp @@ -0,0 +1,201 @@ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/util/result_option_id.hpp" +#include "oneapi/dal/detail/common.hpp" +#include "oneapi/dal/detail/serialization.hpp" +#include "oneapi/dal/table/common.hpp" +#include "oneapi/dal/common.hpp" + +namespace oneapi::dal::spectral_embedding { + +namespace task { +namespace v1 { + +/// Tag-type that parameterizes entities that are used to compute statistics. +struct compute {}; + +/// Alias tag-type for compute task. +using by_default = compute; +} // namespace v1 + +using v1::compute; +using v1::by_default; + +} // namespace task + +namespace method { +namespace v1 { + +/// Tag-type that denotes dense_batch computational method. +struct dense_batch {}; + +/// Alias tag-type for the dense_batch computational method. +using by_default = dense_batch; + +} // namespace v1 + +using v1::dense_batch; +using v1::by_default; + +} // namespace method + +/// Represents result option flag +/// Behaves like a regular :expr`enum`. +class result_option_id : public result_option_id_base { +public: + constexpr result_option_id() = default; + constexpr explicit result_option_id(const result_option_id_base& base) + : result_option_id_base{ base } {} +}; + +namespace detail { + +ONEDAL_EXPORT result_option_id get_embedding_id(); +ONEDAL_EXPORT result_option_id get_eigen_values_id(); + +} // namespace detail + +/// Result options are used to define +/// what should algorithm return +namespace result_options { + +/// Return spectral embedding +const inline auto embedding = detail::get_embedding_id(); + +/// Return eigen values of Laplassian matrix +const inline auto eigen_values = detail::get_eigen_values_id(); + +} // namespace result_options + +namespace detail { + +namespace v1 { + +struct descriptor_tag {}; + +template +class descriptor_impl; + +template +constexpr bool is_valid_float_v = dal::detail::is_one_of_v; + +template +constexpr bool is_valid_method_v = dal::detail::is_one_of_v; + +template +constexpr bool is_valid_task_v = dal::detail::is_one_of_v; + +template +class descriptor_base : public base { + static_assert(is_valid_task_v); + +public: + using tag_t = descriptor_tag; + using float_t = float; + using method_t = method::by_default; + using task_t = Task; + + descriptor_base(); + + std::int64_t get_component_count() const; + std::int64_t get_neighbor_count() const; + result_option_id get_result_options() const; + +protected: + void set_component_count_impl(std::int64_t component_count); + void set_neighbor_count_impl(std::int64_t neighbor_count); + void set_result_options_impl(const result_option_id& value); + +private: + dal::detail::pimpl> impl_; +}; + +} // namespace v1 + +using v1::descriptor_tag; +using v1::descriptor_impl; +using v1::descriptor_base; + +using v1::is_valid_float_v; +using v1::is_valid_method_v; +using v1::is_valid_task_v; + +} // namespace detail + +namespace v1 { + +/// @tparam Float The floating-point type that the algorithm uses for +/// intermediate computations. Can be :expr:`float` or +/// :expr:`double`. +/// @tparam Method Tag-type that specifies an implementation of algorithm. Can +/// be :expr:`method::dense_batch`. +/// @tparam Task Tag-type that specifies the type of the problem to solve. Can +/// be :expr:`task::compute`. + +template +class descriptor : public detail::descriptor_base { + static_assert(detail::is_valid_float_v); + static_assert(detail::is_valid_method_v); + static_assert(detail::is_valid_task_v); + using base_t = detail::descriptor_base; + +public: + using float_t = Float; + using method_t = Method; + using task_t = Task; + + /// Creates a new instance of the class with the default property values. + explicit descriptor() : base_t() {} + + std::int64_t get_component_count() const { + return base_t::get_component_count(); + } + + std::int64_t get_neighbor_count() const { + return base_t::get_neighbor_count(); + } + + auto& set_component_count(std::int64_t component_count) { + base_t::set_component_count_impl(component_count); + return *this; + } + + auto& set_neighbor_count(std::int64_t neighbor_count) { + base_t::set_neighbor_count_impl(neighbor_count); + return *this; + } + + /// Choose which results should be computed and returned. + result_option_id get_result_options() const { + return base_t::get_result_options(); + } + + auto& set_result_options(const result_option_id& value) { + base_t::set_result_options_impl(value); + return *this; + } +}; + +} // namespace v1 + +using v1::descriptor; + +} // namespace oneapi::dal::spectral_embedding diff --git a/cpp/oneapi/dal/algo/spectral_embedding/compute.hpp b/cpp/oneapi/dal/algo/spectral_embedding/compute.hpp new file mode 100644 index 00000000000..ceb49cc42fd --- /dev/null +++ b/cpp/oneapi/dal/algo/spectral_embedding/compute.hpp @@ -0,0 +1,31 @@ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/algo/spectral_embedding/compute_types.hpp" +#include "oneapi/dal/algo/spectral_embedding/detail/compute_ops.hpp" +#include "oneapi/dal/compute.hpp" + +namespace oneapi::dal::detail { +namespace v1 { + +template +struct compute_ops + : dal::spectral_embedding::detail::compute_ops {}; + +} // namespace v1 +} // namespace oneapi::dal::detail diff --git a/cpp/oneapi/dal/algo/spectral_embedding/compute_types.cpp b/cpp/oneapi/dal/algo/spectral_embedding/compute_types.cpp new file mode 100644 index 00000000000..842ffd1134e --- /dev/null +++ b/cpp/oneapi/dal/algo/spectral_embedding/compute_types.cpp @@ -0,0 +1,109 @@ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "oneapi/dal/algo/spectral_embedding/compute_types.hpp" +#include "oneapi/dal/detail/common.hpp" +#include "oneapi/dal/exceptions.hpp" + +namespace oneapi::dal::spectral_embedding { + +template +class detail::v1::compute_input_impl : public base { +public: + compute_input_impl(const table& data) : data(data) {} + table data; +}; + +template +class detail::v1::compute_result_impl : public base { +public: + table embedding; + table eigen_values; + result_option_id options; +}; + +using detail::v1::compute_input_impl; +using detail::v1::compute_result_impl; + +namespace v1 { + +template +compute_input::compute_input(const table& data) : impl_(new compute_input_impl(data)) {} + +template +const table& compute_input::get_data() const { + return impl_->data; +} + +template +void compute_input::set_data_impl(const table& value) { + impl_->data = value; +} + +template +compute_result::compute_result() : impl_(new compute_result_impl{}) {} + +template +const table& compute_result::get_embedding() const { + using msg = dal::detail::error_messages; + if (!get_result_options().test(result_options::embedding)) { + throw domain_error(msg::this_result_is_not_enabled_via_result_options()); + } + return impl_->embedding; +} + +template +void compute_result::set_embedding_impl(const table& value) { + using msg = dal::detail::error_messages; + if (!get_result_options().test(result_options::embedding)) { + throw domain_error(msg::this_result_is_not_enabled_via_result_options()); + } + impl_->embedding = value; +} + +template +const table& compute_result::get_eigen_values() const { + using msg = dal::detail::error_messages; + if (!get_result_options().test(result_options::eigen_values)) { + throw domain_error(msg::this_result_is_not_enabled_via_result_options()); + } + return impl_->eigen_values; +} + +template +void compute_result::set_eigen_values_impl(const table& value) { + using msg = dal::detail::error_messages; + if (!get_result_options().test(result_options::eigen_values)) { + throw domain_error(msg::this_result_is_not_enabled_via_result_options()); + } + impl_->eigen_values = value; +} + +template +const result_option_id& compute_result::get_result_options() const { + return impl_->options; +} + +template +void compute_result::set_result_options_impl(const result_option_id& value) { + impl_->options = value; +} + +template class ONEDAL_EXPORT compute_input; +template class ONEDAL_EXPORT compute_result; + +} // namespace v1 +} // namespace oneapi::dal::spectral_embedding diff --git a/cpp/oneapi/dal/algo/spectral_embedding/compute_types.hpp b/cpp/oneapi/dal/algo/spectral_embedding/compute_types.hpp new file mode 100644 index 00000000000..010b0b84760 --- /dev/null +++ b/cpp/oneapi/dal/algo/spectral_embedding/compute_types.hpp @@ -0,0 +1,124 @@ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/algo/spectral_embedding/common.hpp" + +namespace oneapi::dal::spectral_embedding { + +namespace detail { +namespace v1 { +template +class compute_input_impl; + +template +class compute_result_impl; + +} // namespace v1 + +using v1::compute_input_impl; +using v1::compute_result_impl; + +} // namespace detail + +namespace v1 { + +/// @tparam Task Tag-type that specifies the type of the problem to solve. Can +/// be :expr:`task::compute`. +template +class compute_input : public base { + static_assert(detail::is_valid_task_v); + +public: + using task_t = Task; + + /// Creates a new instance of the class with the given :literal:`data` + compute_input(const table& data); + + /// An $n \\times p$ table with the training data, where each row stores one + /// feature vector. + /// @remark default = table{} + const table& get_data() const; + + auto& set_data(const table& value) { + set_data_impl(value); + return *this; + } + +protected: + void set_data_impl(const table& value); + +private: + dal::detail::pimpl> impl_; +}; + +/// @tparam Task Tag-type that specifies the type of the problem to solve. Can +/// be :expr:`task::compute`. +template +class compute_result : public base { + static_assert(detail::is_valid_task_v); + +public: + using task_t = Task; + + /// Creates a new instance of the class with the default property values. + compute_result(); + + /// The matrix of size $n \\times k$ with + /// spectral embeddings. + /// @remark default = table{} + const table& get_embedding() const; + + auto& set_embedding(const table& value) { + set_embedding_impl(value); + return *this; + } + + /// The matrix of size $n \\times 1$ with + /// eigen values of Laplassian matrix. + /// @remark default = table{} + const table& get_eigen_values() const; + + auto& set_eigen_values(const table& value) { + set_eigen_values_impl(value); + return *this; + } + + /// Result options that indicates availability of the properties + /// @remark default = default_result_options + const result_option_id& get_result_options() const; + + auto& set_result_options(const result_option_id& value) { + set_result_options_impl(value); + return *this; + } + +protected: + void set_embedding_impl(const table&); + void set_eigen_values_impl(const table&); + void set_result_options_impl(const result_option_id&); + +private: + dal::detail::pimpl> impl_; +}; + +} // namespace v1 + +using v1::compute_input; +using v1::compute_result; + +} // namespace oneapi::dal::spectral_embedding diff --git a/cpp/oneapi/dal/algo/spectral_embedding/detail/compute_ops.cpp b/cpp/oneapi/dal/algo/spectral_embedding/detail/compute_ops.cpp new file mode 100644 index 00000000000..3cbec045d3b --- /dev/null +++ b/cpp/oneapi/dal/algo/spectral_embedding/detail/compute_ops.cpp @@ -0,0 +1,42 @@ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "oneapi/dal/algo/spectral_embedding/detail/compute_ops.hpp" +#include "oneapi/dal/algo/spectral_embedding/backend/cpu/compute_kernel.hpp" +#include "oneapi/dal/backend/dispatcher.hpp" + +namespace oneapi::dal::spectral_embedding::detail { +namespace v1 { + +template +struct compute_ops_dispatcher { + compute_result operator()(const Policy& policy, + const descriptor_base& desc, + const compute_input& input) const { + using kernel_dispatcher_t = dal::backend::kernel_dispatcher< // + KERNEL_SINGLE_NODE_CPU(backend::compute_kernel_cpu)>; + return kernel_dispatcher_t()(policy, desc, input); + } +}; + +#define INSTANTIATE(F, M, T) \ + template struct ONEDAL_EXPORT compute_ops_dispatcher; + +INSTANTIATE(float, method::dense_batch, task::compute) +INSTANTIATE(double, method::dense_batch, task::compute) + +} // namespace v1 +} // namespace oneapi::dal::spectral_embedding::detail diff --git a/cpp/oneapi/dal/algo/spectral_embedding/detail/compute_ops.hpp b/cpp/oneapi/dal/algo/spectral_embedding/detail/compute_ops.hpp new file mode 100644 index 00000000000..8b988c2fb7c --- /dev/null +++ b/cpp/oneapi/dal/algo/spectral_embedding/detail/compute_ops.hpp @@ -0,0 +1,78 @@ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/algo/spectral_embedding/compute_types.hpp" +#include "oneapi/dal/detail/error_messages.hpp" + +namespace oneapi::dal::spectral_embedding::detail { +namespace v1 { + +template +struct compute_ops_dispatcher { + compute_result operator()(const Context&, + const descriptor_base&, + const compute_input&) const; +}; + +template +struct compute_ops { + using float_t = typename Descriptor::float_t; + using method_t = typename Descriptor::method_t; + using task_t = typename Descriptor::task_t; + using input_t = compute_input; + using result_t = compute_result; + using descriptor_base_t = descriptor_base; + + void check_preconditions(const Descriptor& params, const input_t& input) const { + using msg = dal::detail::error_messages; + + if (!input.get_data().has_data()) { + throw domain_error(msg::input_data_is_empty()); + } + } + + void check_postconditions(const Descriptor& params, + const input_t& input, + const result_t& result) const { + using msg = dal::detail::error_messages; + std::int64_t n = input.get_data().get_row_count(); + if (result.get_result_options().test(result_options::embedding)) { + if (!result.get_embedding().has_data()) { + throw domain_error(msg::value_is_not_provided()); // TODO: update error message! + } + if (result.get_embedding().get_row_count() != n) { + throw domain_error(msg::incorrect_output_table_size()); + } + } + } + + template + auto operator()(const Context& ctx, const Descriptor& desc, const input_t& input) const { + check_preconditions(desc, input); + const auto result = + compute_ops_dispatcher()(ctx, desc, input); + check_postconditions(desc, input, result); + return result; + } +}; + +} // namespace v1 + +using v1::compute_ops; + +} // namespace oneapi::dal::spectral_embedding::detail diff --git a/cpp/oneapi/dal/algo/spectral_embedding/detail/compute_ops_dpc.cpp b/cpp/oneapi/dal/algo/spectral_embedding/detail/compute_ops_dpc.cpp new file mode 100644 index 00000000000..420fa869b69 --- /dev/null +++ b/cpp/oneapi/dal/algo/spectral_embedding/detail/compute_ops_dpc.cpp @@ -0,0 +1,45 @@ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "oneapi/dal/algo/spectral_embedding/backend/cpu/compute_kernel.hpp" +#include "oneapi/dal/algo/spectral_embedding/backend/gpu/compute_kernel.hpp" +#include "oneapi/dal/algo/spectral_embedding/detail/compute_ops.hpp" +#include "oneapi/dal/backend/dispatcher.hpp" + +namespace oneapi::dal::spectral_embedding::detail { +namespace v1 { + +template +struct compute_ops_dispatcher { + compute_result operator()(const Policy& policy, + const descriptor_base& params, + const compute_input& input) const { + using kernel_dispatcher_t = dal::backend::kernel_dispatcher< + KERNEL_SINGLE_NODE_CPU(backend::compute_kernel_cpu), + KERNEL_SINGLE_NODE_GPU(backend::compute_kernel_gpu)>; + return kernel_dispatcher_t{}(policy, params, input); + } +}; + +#define INSTANTIATE(F, M, T) \ + template struct ONEDAL_EXPORT \ + compute_ops_dispatcher; + +INSTANTIATE(float, method::dense_batch, task::compute) +INSTANTIATE(double, method::dense_batch, task::compute) + +} // namespace v1 +} // namespace oneapi::dal::spectral_embedding::detail diff --git a/cpp/oneapi/dal/algo/spectral_embedding/test/batch.cpp b/cpp/oneapi/dal/algo/spectral_embedding/test/batch.cpp new file mode 100644 index 00000000000..82353d942e0 --- /dev/null +++ b/cpp/oneapi/dal/algo/spectral_embedding/test/batch.cpp @@ -0,0 +1,55 @@ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "oneapi/dal/table/homogen.hpp" +#include "oneapi/dal/table/detail/table_builder.hpp" +#include "oneapi/dal/table/row_accessor.hpp" + +#include "oneapi/dal/test/engine/common.hpp" +#include "oneapi/dal/test/engine/fixtures.hpp" +#include "oneapi/dal/test/engine/dataframe.hpp" +#include "oneapi/dal/test/engine/math.hpp" +#include "oneapi/dal/algo/spectral_embedding/test/fixture.hpp" + +namespace oneapi::dal::spectral_embedding::test { + +namespace te = dal::test::engine; +namespace de = dal::detail; +namespace sp_emb = oneapi::dal::spectral_embedding; + +template +class spectral_embedding_batch_test + : public spectral_embedding_test> { +public: + using base_t = spectral_embedding_test>; + + void gen_dimensions() { + this->n_ = GENERATE(8); + this->p_ = GENERATE(3); + } +}; + +TEMPLATE_LIST_TEST_M(spectral_embedding_batch_test, + "spectral_embedding gold test", + "[spectral embedding][integration][cpu]", + spectral_embedding_types) { + SKIP_IF(this->not_float64_friendly()); + SKIP_IF(this->get_policy().is_gpu()); + + this->test_gold_input(); +} + +} // namespace oneapi::dal::spectral_embedding::test diff --git a/cpp/oneapi/dal/algo/spectral_embedding/test/fixture.hpp b/cpp/oneapi/dal/algo/spectral_embedding/test/fixture.hpp new file mode 100644 index 00000000000..54e490c4515 --- /dev/null +++ b/cpp/oneapi/dal/algo/spectral_embedding/test/fixture.hpp @@ -0,0 +1,131 @@ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#pragma once + +#include "oneapi/dal/algo/spectral_embedding/compute.hpp" + +#include "oneapi/dal/test/engine/common.hpp" +#include "oneapi/dal/test/engine/fixtures.hpp" +#include "oneapi/dal/test/engine/dataframe.hpp" +#include "oneapi/dal/test/engine/math.hpp" +#include "oneapi/dal/detail/debug.hpp" +#include + +namespace oneapi::dal::spectral_embedding::test { + +namespace te = dal::test::engine; +namespace de = dal::detail; +namespace sp_emb = oneapi::dal::spectral_embedding; + +using dal::detail::operator<<; + +template +class spectral_embedding_test : public te::crtp_algo_fixture { +public: + using Float = std::tuple_element_t<0, TestType>; + using Method = std::tuple_element_t<1, TestType>; + using input_t = sp_emb::compute_input<>; + using result_t = sp_emb::compute_result<>; + using descriptor_t = sp_emb::descriptor; + + auto get_descriptor(std::int64_t component_count, + std::int64_t neighbor_count, + sp_emb::result_option_id compute_mode) const { + return descriptor_t() + .set_component_count(component_count) + .set_neighbor_count(neighbor_count) + .set_result_options(compute_mode); + } + + void gen_input() { + std::mt19937 rnd(2007 + n_ + p_ + n_ * p_); + const te::dataframe data_df = + GENERATE_DATAFRAME(te::dataframe_builder{ n_, p_ }.fill_normal(-0.5, 0.5, 7777)); + data_ = data_df.get_table(this->get_policy(), this->get_homogen_table_id()); + } + + void test_gold_input(Float tol = 1e-5) { + constexpr std::int64_t n = 8; + constexpr std::int64_t p = 4; + constexpr std::int64_t neighbor_count = 5; + constexpr std::int64_t component_count = 4; + + constexpr Float data[n * p] = { 0.49671415, -0.1382643, 0.64768854, 1.52302986, + -0.23415337, -0.23413696, 1.57921282, 0.76743473, + -0.46947439, 0.54256004, -0.46341769, -0.46572975, + 0.24196227, -1.91328024, -1.72491783, -0.56228753, + -1.01283112, 0.31424733, -0.90802408, -1.4123037, + 1.46564877, -0.2257763, 0.0675282, -1.42474819, + -0.54438272, 0.11092259, -1.15099358, 0.37569802, + -0.60063869, -0.29169375, -0.60170661, 1.85227818 }; + + constexpr Float gth_embedding[n * component_count] = { + -0.353553391, 0.442842965, 0.190005876, 0.705830111, -0.353553391, 0.604392576, + -0.247517958, -0.595235173, -0.353553391, -0.391745507, 0.0443633719, -0.150208165, + -0.353553391, -0.142548722, 0.0125222995, -0.0318482841, -0.353553391, -0.499390711, + -0.20194266, -0.000639679859, -0.353553391, 0.00809834849, -0.683462258, 0.273398265, + -0.353553391, -0.0977843445, 0.449358299, 0.0195905172, -0.353553391, 0.0761353959, + 0.436673029, -0.220887591 + }; + + constexpr Float gth_eigen_vals[n] = { 0, 3.32674524, 4.70361338, 5.26372220, + 5.69343808, 6.63074948, 6.80173994, 7.57999167 }; + + auto desc = get_descriptor( + component_count, + neighbor_count, + sp_emb::result_options::embedding | sp_emb::result_options::eigen_values); + + table data_ = homogen_table::wrap(data, n, p); + + INFO("run compute"); + auto compute_result = this->compute(desc, data_); + auto embedding = compute_result.get_embedding(); + // std::cout << "Output" << std::endl; + // std::cout << embedding << std::endl; + + array emb_arr = row_accessor(embedding).pull({ 0, -1 }); + for (int j = 0; j < component_count; ++j) { + Float diff = 0, diff_rev = 0; + for (int i = 0; i < n; ++i) { + Float val = emb_arr[i * component_count + j]; + Float gth_val = gth_embedding[i * component_count + j]; + diff = std::max(diff, std::abs(val - gth_val)); + diff_rev = std::max(diff_rev, std::abs(val + gth_val)); + } + REQUIRE((diff < tol || diff_rev < tol)); + } + + auto eigen_values = compute_result.get_eigen_values(); + // std::cout << "Eigen values:" << std::endl; + // std::cout << eigen_values << std::endl; + + array eig_val_arr = row_accessor(eigen_values).pull({ 0, -1 }); + for (int i = 0; i < n; ++i) { + REQUIRE(std::abs(eig_val_arr[i] - gth_eigen_vals[i]) < tol); + } + } + +protected: + std::int64_t n_; + std::int64_t p_; + table data_; +}; + +using spectral_embedding_types = COMBINE_TYPES((float, double), (sp_emb::method::dense_batch)); + +} // namespace oneapi::dal::spectral_embedding::test diff --git a/cpp/oneapi/dal/detail/error_messages.cpp b/cpp/oneapi/dal/detail/error_messages.cpp index 20ce68e0ef2..7fbdab3bb2a 100644 --- a/cpp/oneapi/dal/detail/error_messages.cpp +++ b/cpp/oneapi/dal/detail/error_messages.cpp @@ -262,6 +262,10 @@ MSG(nothing_to_compute, "Invalid combination of optional results: nothing to com MSG(distances_are_uninitialized, "Distances are not set as an optional result") MSG(predecessors_are_uninitialized, "Predecessors are not set as an optional result") +/*Spectral Embedding*/ +MSG(sp_emb_dense_batch_method_is_not_implemented_for_gpu, + "Spectral embedding algorithm is not implemented on GPU") + /* SVM */ MSG(c_leq_zero, "C is lower than or equal to zero") MSG(cache_size_lt_zero, "Cache size is lower than zero") diff --git a/cpp/oneapi/dal/detail/error_messages.hpp b/cpp/oneapi/dal/detail/error_messages.hpp index dbb9c8ba6ec..60e60a925f7 100644 --- a/cpp/oneapi/dal/detail/error_messages.hpp +++ b/cpp/oneapi/dal/detail/error_messages.hpp @@ -291,6 +291,9 @@ class ONEDAL_EXPORT error_messages { MSG(distances_are_uninitialized); MSG(predecessors_are_uninitialized); + /*Spectral Embedding*/ + MSG(sp_emb_dense_batch_method_is_not_implemented_for_gpu); + /* SVM */ MSG(c_leq_zero); MSG(cache_size_lt_zero); diff --git a/examples/oneapi/cpp/BUILD b/examples/oneapi/cpp/BUILD index a90cf71b7df..dc99e223346 100644 --- a/examples/oneapi/cpp/BUILD +++ b/examples/oneapi/cpp/BUILD @@ -76,6 +76,19 @@ dal_example_suite( ], ) +dal_example_suite( + name = "spectral_clustering", + compile_as = [ "c++" ], + srcs = glob(["source/spectral_clustering/*.cpp"]), + dal_deps = [ + "@onedal//cpp/oneapi/dal/algo:kmeans", + "@onedal//cpp/oneapi/dal/algo:kmeans_init", + "@onedal//cpp/oneapi/dal/algo:spectral_embedding", + ], + data = _DATA_DEPS, + extra_deps = _TEST_DEPS, +) + dal_algo_example_suite( algos = [ "basic_statistics", diff --git a/examples/oneapi/cpp/source/spectral_clustering/spectral_clustering_pipeline.cpp b/examples/oneapi/cpp/source/spectral_clustering/spectral_clustering_pipeline.cpp new file mode 100644 index 00000000000..d14a29776cc --- /dev/null +++ b/examples/oneapi/cpp/source/spectral_clustering/spectral_clustering_pipeline.cpp @@ -0,0 +1,90 @@ +/******************************************************************************* +* Copyright contributors to the oneDAL project +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#include "example_util/utils.hpp" +#include "oneapi/dal/algo/kmeans.hpp" +#include "oneapi/dal/algo/kmeans_init.hpp" +#include "oneapi/dal/algo/spectral_embedding.hpp" +#include "oneapi/dal/io/csv.hpp" +#include +#include + +namespace dal = oneapi::dal; + +int main(int argc, char const *argv[]) { + double p = 0.01; // prunning parameter + std::int64_t num_spks = 8; // dimension of spectral embeddings + + std::int64_t cluster_count = num_spks; // number of clusters + std::int64_t max_iteration_count = 300; // max iterations number for K-Means + std::int64_t n_init = 20; // number of K-means++ iterations + double accuracy_threshold = 1e-4; // threshold for early stop in K-Means + + const auto voice_data_file_name = + get_data_path("covcormoments_dense.csv"); // Dataset with original features + + std::cout << voice_data_file_name << std::endl; + + const auto x_train = dal::read(dal::csv::data_source{ voice_data_file_name }); + + std::int64_t m = x_train.get_row_count(); + std::int64_t n_neighbors; + + if (m < 1000) { + n_neighbors = std::min((std::int64_t)10, m - 2); + } + else { + n_neighbors = (std::int64_t)(p * m); + } + + const auto spectral_embedding_desc = + dal::spectral_embedding::descriptor<>() + .set_neighbor_count(n_neighbors) + .set_component_count(num_spks) + .set_result_options(dal::spectral_embedding::result_options::embedding | + dal::spectral_embedding::result_options::eigen_values); + + const auto spectral_embedding_result = dal::compute(spectral_embedding_desc, x_train); + + const auto spectral_embeddings = + spectral_embedding_result.get_embedding(); // Matrix with spectral embeddings m * num_spks + + std::cout << "Spectral embeddings:\n" << spectral_embeddings << std::endl; + + std::cout << "Eigen values:\n" << spectral_embedding_result.get_eigen_values() << std::endl; + + const auto kmeans_init_desc = + dal::kmeans_init::descriptor() + .set_cluster_count(cluster_count) + .set_local_trials_count(n_init); + + const auto kmeans_init_result = dal::compute(kmeans_init_desc, spectral_embeddings); + + const auto initial_centroids = kmeans_init_result.get_centroids(); + + const auto kmeans_desc = dal::kmeans::descriptor<>() + .set_cluster_count(cluster_count) + .set_max_iteration_count(max_iteration_count) + .set_accuracy_threshold(accuracy_threshold); + + const auto spectral_clustering_result = + dal::train(kmeans_desc, spectral_embeddings, initial_centroids); + + std::cout << "Responses:\n" << spectral_clustering_result.get_responses() << std::endl; + std::cout << "Centroids:\n" + << spectral_clustering_result.get_model().get_centroids() << std::endl; + return 0; +} diff --git a/makefile.lst b/makefile.lst index 92dc52ff521..ccc82a8a579 100755 --- a/makefile.lst +++ b/makefile.lst @@ -24,7 +24,7 @@ CORE.ALGORITHMS.CUSTOM.AVAILABLE := low_order_moments quantiles covariance cosdi dtrees/gbt dtrees/forest linear_regression ridge_regression naivebayes stump adaboost brownboost \ logitboost svm multiclassclassifier k_nearest_neighbors logistic_regression implicit_als \ coordinate_descent jaccard triangle_counting shortest_paths subgraph_isomorphism connected_components \ - louvain tsne + louvain tsne spectral_embedding classifier += classifier/inner low_order_moments += @@ -68,6 +68,7 @@ implicit_als += engines distributions engines += engines/mt19937 engines/mcg59 engines/mt2203 distributions += distributions/bernoulli distributions/normal distributions/uniform tsne += +spectral_embedding += cosdistance CORE.ALGORITHMS.FULL := \ adaboost \ @@ -141,7 +142,8 @@ CORE.ALGORITHMS.FULL := \ svd \ svm \ weak_learner/inner \ - tsne + tsne \ + spectral_embedding CORE.ALGORITHMS := $(if $(CORE.ALGORITHMS.CUSTOM), $(CORE.ALGORITHMS.CUSTOM), $(CORE.ALGORITHMS.FULL)) CORE.ALGORITHMS := $(sort $(foreach alg,$(CORE.ALGORITHMS),$(foreach alg1,$($(alg)),$(foreach alg2,$($(alg1)),$($(alg2)) $(alg2)) $(alg1)) $(alg))) @@ -216,6 +218,7 @@ ONEAPI.ALGOS.pca := CORE.pca ONEAPI.ALGOS.polynomial_kernel := CORE.kernel_function ONEAPI.ALGOS.sigmoid_kernel := CORE.kernel_function ONEAPI.ALGOS.rbf_kernel := CORE.kernel_function +ONEAPI.ALGOS.spectral_embedding := CORE.spectral_embedding ONEAPI.ALGOS.svm := CORE.svm # List of algorithms in oneAPI part @@ -244,6 +247,7 @@ ONEAPI.ALGOS := \ polynomial_kernel \ sigmoid_kernel \ rbf_kernel \ + spectral_embedding \ svm \ jaccard \ triangle_counting \