MLIR Passes v1.0
Loading...
Searching...
No Matches
CancellationOperations.hpp
Go to the documentation of this file.
/* This code and any associated documentation is provided "as is"

Copyright 2025 Munich Quantum Software Stack Project

Licensed under the Apache License, Version 2.0 with LLVM Exceptions (the
"License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/Munich-Quantum-Software-Stack/passes/blob/develop/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
License for the specific language governing permissions and limitations under
the License.

SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*******************************************************************************
 author Martin Letras
 date February 2025
 version 1.0
*******************************************************************************
* This source code and the accompanying materials are made available under *
* the terms of the Apache License 2.0 which accompanies this distribution. *
******************************************************************************/

36#pragma once
37
#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/DialectConversion.h"

#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"

#include <algorithm>
#include <vector>

using namespace mlir;
using namespace mqss::support::quakeDialect;
50
51namespace mqss::support::transforms {
52
68template <typename T1, typename T2>
69void patternCancellation(mlir::Operation *currentOp, int nCtrlsOp1,
70 int nTgtsOp1, int nCtrlsOp2, int nTgtsOp2) {
71 auto currentGate = dyn_cast_or_null<T2>(*currentOp);
72 if (!currentGate)
73 return;
74 // check that the current gate is compliant with the number of controls and
75 // targets
76 if (currentGate.getControls().size() != nCtrlsOp2 ||
77 currentGate.getTargets().size() != nTgtsOp2)
78 return;
79 // get the previous operation to check the swap pattern
80 auto prevOp =
81 getPreviousOperationOnTarget(currentGate, currentGate.getTargets()[0]);
82 if (!prevOp)
83 return;
84 auto previousGate = dyn_cast_or_null<T1>(prevOp);
85 if (!previousGate)
86 return;
87 // check that the previous gate is compliant with the number of controls and
88 // targets
89 if (previousGate.getControls().size() != nCtrlsOp1 ||
90 previousGate.getTargets().size() != nTgtsOp1)
91 return;
92 // check that targets and controls are the same!
93 // At the moment I am checking all controls and all targets!
94 if (currentGate.getControls().size() == previousGate.getControls().size()) {
95 std::vector<int> controlsCurr =
96 getIndicesOfValueRange(currentGate.getControls());
97 std::vector<int> controlsPrev =
98 getIndicesOfValueRange(previousGate.getControls());
99 // sort both arrays
100 std::sort(controlsCurr.begin(), controlsCurr.end(), std::greater<int>());
101 std::sort(controlsPrev.begin(), controlsPrev.end(), std::greater<int>());
102 // compare both arrays
103 if (!(std::equal(controlsCurr.begin(), controlsCurr.end(),
104 controlsPrev.begin())))
105 return;
106 } else
107 return;
108 // so far, controls are the same, now check the targets
109 if (currentGate.getTargets().size() == previousGate.getTargets().size()) {
110 std::vector<int> targetsCurr =
111 getIndicesOfValueRange(currentGate.getTargets());
112 std::vector<int> targetsPrev =
113 getIndicesOfValueRange(previousGate.getTargets());
114 // sort both arrays
115 std::sort(targetsCurr.begin(), targetsCurr.end(), std::greater<int>());
116 std::sort(targetsPrev.begin(), targetsPrev.end(), std::greater<int>());
117 // compare both arrays
118 if (!(std::equal(targetsCurr.begin(), targetsCurr.end(),
119 targetsPrev.begin())))
120 return;
121 } else
122 return;
123#ifdef DEBUG
124 llvm::outs() << "Current Operation: ";
125 currentGate->print(llvm::outs());
126 llvm::outs() << "\n";
127 llvm::outs() << "Previous Operation: ";
128 previousGate->print(llvm::outs());
129 llvm::outs() << "\n";
130#endif
131 // At this point, I should de able to remove the pattern
132 mlir::IRRewriter rewriter(currentGate->getContext());
133 // Erase the operations
134 rewriter.eraseOp(currentGate);
135 rewriter.eraseOp(previousGate);
136}
137} // namespace mqss::support::transforms
void patternCancellation(mlir::Operation *currentOp, int nCtrlsOp1, int nTgtsOp1, int nCtrlsOp2, int nTgtsOp2)
Function that removes (cancels) a pattern of two quantum operations under specific constraints.
Definition CancellationOperations.hpp:69
mlir::Operation * getPreviousOperationOnTarget(mlir::Operation *currentOp, mlir::Value targetQubit)
Function that gets the previous operation on a given target qubit.
std::vector< int > getIndicesOfValueRange(mlir::ValueRange array)
Function that gets a vector of indices associated with a given mlir::ValueRange.