Skip to content

Commit e6cd72a

Browse files
committed
[DAPHNE-#793] Prevent merging transpose into initial SPGEMM
This commit adds a quick fix to not merge transpose operations into sparse matrix multiplication as our CPP kernels can not handle this. The temporary workaround until we support transposition uses a hard coded sparsity threshold as there is no user config available in the canonicalization. The matrix representation selection functionality would also be a suitable measure but is not implemented in a suitable way for this at the moment.
1 parent b2ba6eb commit e6cd72a

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

src/compiler/execution/DaphneIrExecutor.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,11 @@ bool DaphneIrExecutor::runPasses(mlir::ModuleOp module) {
131131
pm.addNestedPass<mlir::func::FuncOp>(mlir::daphne::createInferencePass());
132132
pm.addPass(mlir::createCanonicalizerPass());
133133

134-
if (selectMatrixRepresentations_)
135-
pm.addNestedPass<mlir::func::FuncOp>(
136-
mlir::daphne::createSelectMatrixRepresentationsPass(userConfig_));
134+
if (selectMatrixRepresentations_) {
135+
pm.addNestedPass<mlir::func::FuncOp>(mlir::daphne::createSelectMatrixRepresentationsPass(userConfig_));
136+
pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
137+
}
138+
137139
if (userConfig_.explain_select_matrix_repr)
138140
pm.addPass(mlir::daphne::createPrintIRPass(
139141
"IR after selecting matrix representations:"));

src/ir/daphneir/DaphneDialect.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -1119,6 +1119,15 @@ mlir::LogicalResult mlir::daphne::MatMulOp::canonicalize(
11191119
return mlir::failure();
11201120
}
11211121

1122+
// ToDo: This check prevents merging transpose into matrix multiplication because that is not yet supported by our
1123+
// sparse kernels.
1124+
// ToDo: bring user config here for sparsity threshold or properly use MatrixRepresentation
1125+
if(auto t = rhs.getType().dyn_cast<mlir::daphne::MatrixType>()) {
1126+
auto sparsity = t.getSparsity();
1127+
if(sparsity < 0.25)
1128+
return mlir::failure();
1129+
}
1130+
11221131
#if 0
11231132
// TODO Adapt PhyOperatorSelectionPass once this code is turned on again.
11241133
if(lhsTransposeOp) {

0 commit comments

Comments
 (0)