automatic code formatting
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-13 21:51:19 +02:00
parent 55eda487dc
commit 8d95c604a6
13 changed files with 100 additions and 111 deletions
+1 -2
View File
@@ -1,7 +1,6 @@
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypeInterfaces.h"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
using namespace mlir; using namespace mlir;
+1 -2
View File
@@ -1,8 +1,7 @@
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
#include "llvm/Support/Format.h" #include "llvm/Support/Format.h"
#include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp" #include "src/Accelerators/PIM/Common/Support/FileSystemUtils.hpp"
#include "src/Accelerators/PIM/Common/Support/ReportUtils.hpp"
namespace onnx_mlir { namespace onnx_mlir {
+1 -2
View File
@@ -1,10 +1,9 @@
#pragma once #pragma once
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <fstream> #include <fstream>
#include <limits> #include <limits>
#include <string> #include <string>
+36 -43
View File
@@ -70,9 +70,7 @@ inline void writeUint32LE(llvm::raw_ostream& os, uint32_t value) {
os.write(bytes.data(), bytes.size()); os.write(bytes.data(), bytes.size());
} }
inline void writeInt32LE(llvm::raw_ostream& os, int32_t value) { inline void writeInt32LE(llvm::raw_ostream& os, int32_t value) { writeUint32LE(os, static_cast<uint32_t>(value)); }
writeUint32LE(os, static_cast<uint32_t>(value));
}
inline void writeHeader(llvm::raw_ostream& os) { inline void writeHeader(llvm::raw_ostream& os) {
os.write(kMagic, sizeof(kMagic)); os.write(kMagic, sizeof(kMagic));
@@ -186,39 +184,39 @@ inline Opcode opcodeFromString(llvm::StringRef opName) {
inline llvm::StringRef opcodeToString(Opcode opcode) { inline llvm::StringRef opcodeToString(Opcode opcode) {
switch (opcode) { switch (opcode) {
case Opcode::nop: return "nop"; case Opcode::nop: return "nop";
case Opcode::sldi: return "sldi"; case Opcode::sldi: return "sldi";
case Opcode::sld: return "sld"; case Opcode::sld: return "sld";
case Opcode::sadd: return "sadd"; case Opcode::sadd: return "sadd";
case Opcode::ssub: return "ssub"; case Opcode::ssub: return "ssub";
case Opcode::smul: return "smul"; case Opcode::smul: return "smul";
case Opcode::saddi: return "saddi"; case Opcode::saddi: return "saddi";
case Opcode::smuli: return "smuli"; case Opcode::smuli: return "smuli";
case Opcode::setbw: return "setbw"; case Opcode::setbw: return "setbw";
case Opcode::mvmul: return "mvmul"; case Opcode::mvmul: return "mvmul";
case Opcode::vvadd: return "vvadd"; case Opcode::vvadd: return "vvadd";
case Opcode::vvsub: return "vvsub"; case Opcode::vvsub: return "vvsub";
case Opcode::vvmul: return "vvmul"; case Opcode::vvmul: return "vvmul";
case Opcode::vvdmul: return "vvdmul"; case Opcode::vvdmul: return "vvdmul";
case Opcode::vvmax: return "vvmax"; case Opcode::vvmax: return "vvmax";
case Opcode::vvsll: return "vvsll"; case Opcode::vvsll: return "vvsll";
case Opcode::vvsra: return "vvsra"; case Opcode::vvsra: return "vvsra";
case Opcode::vavg: return "vavg"; case Opcode::vavg: return "vavg";
case Opcode::vrelu: return "vrelu"; case Opcode::vrelu: return "vrelu";
case Opcode::vtanh: return "vtanh"; case Opcode::vtanh: return "vtanh";
case Opcode::vsigm: return "vsigm"; case Opcode::vsigm: return "vsigm";
case Opcode::vsoftmax: return "vsoftmax"; case Opcode::vsoftmax: return "vsoftmax";
case Opcode::vmv: return "vmv"; case Opcode::vmv: return "vmv";
case Opcode::vrsu: return "vrsu"; case Opcode::vrsu: return "vrsu";
case Opcode::vrsl: return "vrsl"; case Opcode::vrsl: return "vrsl";
case Opcode::ld: return "ld"; case Opcode::ld: return "ld";
case Opcode::st: return "st"; case Opcode::st: return "st";
case Opcode::lldi: return "lldi"; case Opcode::lldi: return "lldi";
case Opcode::lmv: return "lmv"; case Opcode::lmv: return "lmv";
case Opcode::send: return "send"; case Opcode::send: return "send";
case Opcode::recv: return "recv"; case Opcode::recv: return "recv";
case Opcode::wait: return "wait"; case Opcode::wait: return "wait";
case Opcode::sync: return "sync"; case Opcode::sync: return "sync";
} }
llvm_unreachable("Unsupported PIM binary opcode"); llvm_unreachable("Unsupported PIM binary opcode");
} }
@@ -235,9 +233,7 @@ inline InstructionRecord makeInstructionRecord(const llvm::json::Object& instruc
case Opcode::sldi: case Opcode::sldi:
case Opcode::saddi: case Opcode::saddi:
case Opcode::smuli: case Opcode::smuli:
case Opcode::lldi: case Opcode::lldi: record.r2OrImm = getOptionalInt(instruction, "imm"); break;
record.r2OrImm = getOptionalInt(instruction, "imm");
break;
case Opcode::mvmul: case Opcode::mvmul:
record.r2OrImm = getOptionalInt(instruction, "mbiw"); record.r2OrImm = getOptionalInt(instruction, "mbiw");
record.generic1 = getOptionalInt(instruction, "relu"); record.generic1 = getOptionalInt(instruction, "relu");
@@ -252,9 +248,7 @@ inline InstructionRecord makeInstructionRecord(const llvm::json::Object& instruc
record.r2OrImm = getOptionalInt(instruction, "core"); record.r2OrImm = getOptionalInt(instruction, "core");
record.generic3 = getOptionalInt(instruction, "size"); record.generic3 = getOptionalInt(instruction, "size");
break; break;
default: default: record.r2OrImm = getOptionalInt(instruction, "rs2"); break;
record.r2OrImm = getOptionalInt(instruction, "rs2");
break;
} }
if (record.opcode != Opcode::mvmul && record.opcode != Opcode::setbw) { if (record.opcode != Opcode::mvmul && record.opcode != Opcode::setbw) {
@@ -371,8 +365,7 @@ inline llvm::json::Object makeInstructionJson(const InstructionRecord& record) {
break; break;
case Opcode::wait: case Opcode::wait:
case Opcode::sync: case Opcode::sync:
case Opcode::nop: case Opcode::nop: break;
break;
} }
return instruction; return instruction;
+1 -1
View File
@@ -367,7 +367,7 @@ void PimCodeGen::emitMemCopyOp(StringRef opName,
instruction.generic1 = 0; instruction.generic1 = 0;
instruction.generic2 = 0; instruction.generic2 = 0;
instruction.generic3 = static_cast<int32_t>(size); instruction.generic3 = static_cast<int32_t>(size);
(void)sizeFieldName; (void) sizeFieldName;
emitInstruction(instruction); emitInstruction(instruction);
} }
@@ -1,5 +1,4 @@
#include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp" #include "src/Accelerators/PIM/Conversion/SpatialToPim/TensorPackingPatterns.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
using namespace mlir; using namespace mlir;
@@ -75,16 +74,14 @@ struct PackSpatialConcatInputsPattern final : OpRewritePattern<spatial::SpatConc
return failure(); return failure();
auto outputType = cast<ShapedType>(concatOp.getOutput().getType()); auto outputType = cast<ShapedType>(concatOp.getOutput().getType());
auto newConcat = pim::PimConcatOp::create(rewriter, auto newConcat = pim::PimConcatOp::create(
concatOp.getLoc(), rewriter,
concatOp.getOutput().getType(), concatOp.getLoc(),
concatOp.getAxisAttr(), concatOp.getOutput().getType(),
ValueRange(packedInputs), concatOp.getAxisAttr(),
tensor::EmptyOp::create(rewriter, ValueRange(packedInputs),
concatOp.getLoc(), tensor::EmptyOp::create(rewriter, concatOp.getLoc(), outputType.getShape(), outputType.getElementType())
outputType.getShape(), .getResult());
outputType.getElementType())
.getResult());
rewriter.replaceOp(concatOp, newConcat.getOutput()); rewriter.replaceOp(concatOp, newConcat.getOutput());
return success(); return success();
} }
@@ -1,7 +1,7 @@
#pragma once #pragma once
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
@@ -1,15 +1,15 @@
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h" #include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/STLExtras.h"
#include <limits> #include <limits>
#include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp"
using namespace mlir; using namespace mlir;
namespace onnx_mlir { namespace onnx_mlir {
@@ -29,9 +29,8 @@ static uint64_t getTypeSizeBytes(MemRefType type) {
return static_cast<uint64_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8); return static_cast<uint64_t>(type.getNumElements() * type.getElementTypeBitWidth() / 8);
} }
static FailureOr<uint64_t> getLastUseInstruction(memref::AllocOp allocOp, static FailureOr<uint64_t>
Block& body, getLastUseInstruction(memref::AllocOp allocOp, Block& body, const DenseMap<Operation*, uint64_t>& opOrder) {
const DenseMap<Operation*, uint64_t>& opOrder) {
uint64_t endInstruction = opOrder.lookup(allocOp); uint64_t endInstruction = opOrder.lookup(allocOp);
SmallPtrSet<Operation*, 16> visited; SmallPtrSet<Operation*, 16> visited;
SmallVector<Value> pendingValues; SmallVector<Value> pendingValues;
@@ -45,10 +44,9 @@ static FailureOr<uint64_t> getLastUseInstruction(memref::AllocOp allocOp,
if (!visited.insert(user).second) if (!visited.insert(user).second)
continue; continue;
if (isSupportedAliasOp(user)) { if (isSupportedAliasOp(user))
for (Value result : user->getResults()) for (Value result : user->getResults())
pendingValues.push_back(result); pendingValues.push_back(result);
}
if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) { if (auto dpsOp = dyn_cast<DestinationStyleOpInterface>(user)) {
for (OpResult result : user->getResults()) { for (OpResult result : user->getResults()) {
@@ -2,7 +2,6 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h" #include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
@@ -45,9 +45,7 @@ struct CoalescingReportEntry {
CoalescingReportRow row; CoalescingReportRow row;
}; };
static std::string formatMemory(uint64_t bytes) { static std::string formatMemory(uint64_t bytes) { return formatReportMemory(bytes); }
return formatReportMemory(bytes);
}
static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) { static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName); auto coreIdsAttr = coreBatchOp->getAttrOfType<DenseI32ArrayAttr>(onnx_mlir::kCoreIdsAttrName);
@@ -58,9 +56,10 @@ static SmallVector<int32_t> getBatchCoreIds(pim::PimCoreBatchOp coreBatchOp) {
static void printReportRow(raw_ostream& os, const CoalescingReportRow& row) { static void printReportRow(raw_ostream& os, const CoalescingReportRow& row) {
llvm::SmallVector<ReportField, 4> fields = { llvm::SmallVector<ReportField, 4> fields = {
{"Number of candidates", std::to_string(row.numCandidates)}, {"Number of candidates", std::to_string(row.numCandidates)},
{"Skipped allocations", std::to_string(row.numSkipped)}, {"Skipped allocations", std::to_string(row.numSkipped) },
{"Removed allocations", std::to_string(row.numRemoved)}, {"Removed allocations", std::to_string(row.numRemoved) },
{"Saved memory", formatMemory(row.savedBytes)}}; {"Saved memory", formatMemory(row.savedBytes) }
};
printReportFlatFields(os, fields); printReportFlatFields(os, fields);
} }
@@ -87,10 +86,12 @@ static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
totalRow.savedBytes += entryTotal.savedBytes; totalRow.savedBytes += entryTotal.savedBytes;
} }
llvm::SmallVector<ReportField, 4> totalFields = {{"Number of candidates", std::to_string(totalRow.numCandidates)}, llvm::SmallVector<ReportField, 4> totalFields = {
{"Skipped allocations", std::to_string(totalRow.numSkipped)}, {"Number of candidates", std::to_string(totalRow.numCandidates)},
{"Removed allocations", std::to_string(totalRow.numRemoved)}, {"Skipped allocations", std::to_string(totalRow.numSkipped) },
{"Saved memory", formatMemory(totalRow.savedBytes)}}; {"Removed allocations", std::to_string(totalRow.numRemoved) },
{"Saved memory", formatMemory(totalRow.savedBytes) }
};
printReportTotalsBlock(os, totalFields); printReportTotalsBlock(os, totalFields);
if (!entries.empty()) if (!entries.empty())
os << "\n"; os << "\n";
@@ -127,15 +128,17 @@ static void emitReport(ArrayRef<CoalescingReportEntry> entries) {
if (sortedEntries[index].kind == CoalescingReportEntry::Kind::Batch) { if (sortedEntries[index].kind == CoalescingReportEntry::Kind::Batch) {
llvm::SmallVector<ReportField, 4> perCoreFields = { llvm::SmallVector<ReportField, 4> perCoreFields = {
{"Number of candidates", std::to_string(sortedEntries[index].row.numCandidates)}, {"Number of candidates", std::to_string(sortedEntries[index].row.numCandidates)},
{"Skipped allocations", std::to_string(sortedEntries[index].row.numSkipped)}, {"Skipped allocations", std::to_string(sortedEntries[index].row.numSkipped) },
{"Removed allocations", std::to_string(sortedEntries[index].row.numRemoved)}, {"Removed allocations", std::to_string(sortedEntries[index].row.numRemoved) },
{"Saved memory", formatMemory(sortedEntries[index].row.savedBytes)}}; {"Saved memory", formatMemory(sortedEntries[index].row.savedBytes) }
};
CoalescingReportRow totalRow = getTotalRow(sortedEntries[index]); CoalescingReportRow totalRow = getTotalRow(sortedEntries[index]);
llvm::SmallVector<ReportField, 4> totalFields = { llvm::SmallVector<ReportField, 4> totalFields = {
{"Number of candidates", std::to_string(totalRow.numCandidates)}, {"Number of candidates", std::to_string(totalRow.numCandidates)},
{"Skipped allocations", std::to_string(totalRow.numSkipped)}, {"Skipped allocations", std::to_string(totalRow.numSkipped) },
{"Removed allocations", std::to_string(totalRow.numRemoved)}, {"Removed allocations", std::to_string(totalRow.numRemoved) },
{"Saved memory", formatMemory(totalRow.savedBytes)}}; {"Saved memory", formatMemory(totalRow.savedBytes) }
};
printReportPerCoreAndTotalFields(os, perCoreFields, totalFields); printReportPerCoreAndTotalFields(os, perCoreFields, totalFields);
} }
else { else {
@@ -196,8 +199,6 @@ struct StaticMemoryCoalescingPass : PassWrapper<StaticMemoryCoalescingPass, Oper
} // namespace } // namespace
std::unique_ptr<Pass> createPimStaticMemoryCoalescingPass() { std::unique_ptr<Pass> createPimStaticMemoryCoalescingPass() { return std::make_unique<StaticMemoryCoalescingPass>(); }
return std::make_unique<StaticMemoryCoalescingPass>();
}
} // namespace onnx_mlir } // namespace onnx_mlir
@@ -818,13 +818,14 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
} }
} }
llvm::SmallVector<ReportField, 6> totalFields = {{"Used cores", std::to_string(usedCpuCount)}, llvm::SmallVector<ReportField, 6> totalFields = {
{"Number of top-level compute ops", std::to_string(totalComputeOps)}, {"Used cores", std::to_string(usedCpuCount) },
{"Number of logical computes", std::to_string(totalLogicalComputes)}, {"Number of top-level compute ops", std::to_string(totalComputeOps) },
{"Number of top-level batch compute ops", {"Number of logical computes", std::to_string(totalLogicalComputes) },
std::to_string(totalBatchComputeOps)}, {"Number of top-level batch compute ops", std::to_string(totalBatchComputeOps) },
{"Number of instructions", std::to_string(totalInstructionCount)}, {"Number of instructions", std::to_string(totalInstructionCount)},
{"Number of used crossbars", std::to_string(totalWeightCount)}}; {"Number of used crossbars", std::to_string(totalWeightCount) }
};
printReportTotalsBlock(os, totalFields); printReportTotalsBlock(os, totalFields);
if (!collectedData.empty()) if (!collectedData.empty())
os << "\n"; os << "\n";
@@ -876,13 +877,15 @@ void generateReport(func::FuncOp funcOp, const std::string& name, size_t usedCpu
llvm::SmallVector<ReportField, 3> perCoreFields = { llvm::SmallVector<ReportField, 3> perCoreFields = {
{"Number of logical computes", std::to_string(perCoreLogicalComputeCount)}, {"Number of logical computes", std::to_string(perCoreLogicalComputeCount)},
{"Number of instructions", std::to_string(perCoreInstructionCount)}, {"Number of instructions", std::to_string(perCoreInstructionCount) },
{"Number of used crossbars", std::to_string(perCoreWeightCount)}}; {"Number of used crossbars", std::to_string(perCoreWeightCount) }
};
if (current.isRebatched) { if (current.isRebatched) {
llvm::SmallVector<ReportField, 3> totalEntryFields = { llvm::SmallVector<ReportField, 3> totalEntryFields = {
{"Number of logical computes", std::to_string(current.logicalComputeCount)}, {"Number of logical computes", std::to_string(current.logicalComputeCount)},
{"Number of instructions", std::to_string(totalEntryInstructionCount)}, {"Number of instructions", std::to_string(totalEntryInstructionCount) },
{"Number of used crossbars", std::to_string(current.weightCount)}}; {"Number of used crossbars", std::to_string(current.weightCount) }
};
printReportPerCoreAndTotalFields(os, perCoreFields, totalEntryFields); printReportPerCoreAndTotalFields(os, perCoreFields, totalEntryFields);
} }
else { else {
@@ -66,8 +66,10 @@ static Value buildSubviewChunk(const StaticSubviewInfo& info,
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides); return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
} }
static SmallVector<Value> static SmallVector<Value> delinearizeIndexValue(Value linearIndex,
delinearizeIndexValue(Value linearIndex, ArrayRef<int64_t> shape, ArrayRef<int64_t> strides, PatternRewriter& rewriter) { ArrayRef<int64_t> shape,
ArrayRef<int64_t> strides,
PatternRewriter& rewriter) {
SmallVector<Value> indices; SmallVector<Value> indices;
indices.reserve(shape.size()); indices.reserve(shape.size());
@@ -112,7 +114,8 @@ static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info,
assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides"); assert(info.strides[dim] == 1 && "loop-based subview rewrite requires unit strides");
chunkOffsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter)); chunkOffsets.push_back(addDynamicOffset(info.offsets[dim], outerIndices[dim], rewriter));
chunkSizes.push_back(rewriter.getIndexAttr(1)); chunkSizes.push_back(rewriter.getIndexAttr(1));
} else { }
else {
chunkOffsets.push_back(info.offsets[dim]); chunkOffsets.push_back(info.offsets[dim]);
chunkSizes.push_back(rewriter.getIndexAttr(info.sizes.back())); chunkSizes.push_back(rewriter.getIndexAttr(info.sizes.back()));
} }
@@ -122,11 +125,8 @@ static Value buildDynamicSubviewChunk(const StaticSubviewInfo& info,
return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides); return memref::SubViewOp::create(rewriter, loc, info.source, chunkOffsets, chunkSizes, chunkStrides);
} }
static Value buildContiguousChunk(Value source, static Value buildContiguousChunk(
ArrayRef<int64_t> copyShape, Value source, ArrayRef<int64_t> copyShape, ArrayRef<Value> outerIndices, Location loc, PatternRewriter& rewriter) {
ArrayRef<Value> outerIndices,
Location loc,
PatternRewriter& rewriter) {
SmallVector<OpFoldResult> chunkOffsets; SmallVector<OpFoldResult> chunkOffsets;
SmallVector<OpFoldResult> chunkSizes; SmallVector<OpFoldResult> chunkSizes;
SmallVector<OpFoldResult> chunkStrides; SmallVector<OpFoldResult> chunkStrides;
@@ -203,7 +203,8 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp,
rewriter.setInsertionPointToStart(loop.getBody()); rewriter.setInsertionPointToStart(loop.getBody());
SmallVector<Value> outerIndices = SmallVector<Value> outerIndices =
outerShape.empty() ? SmallVector<Value> {} : delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter); outerShape.empty() ? SmallVector<Value> {}
: delinearizeIndexValue(loop.getInductionVar(), outerShape, outerStrides, rewriter);
Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter) Value chunkDst = splitDst ? buildDynamicSubviewChunk(*dstSubview, outerIndices, copyOp.getLoc(), rewriter)
: buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter); : buildContiguousChunk(dst, copyShape, outerIndices, copyOp.getLoc(), rewriter);
Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter) Value chunkSrc = splitSrc ? buildDynamicSubviewChunk(*srcSubview, outerIndices, copyOp.getLoc(), rewriter)
+1 -1
View File
@@ -6,10 +6,10 @@
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp" #include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp" #include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Accelerators/PIM/Common/IR/SubviewUtils.hpp"
using namespace mlir; using namespace mlir;