diff --git a/src/PIM/Common/IR/AddressAnalysis.cpp b/src/PIM/Common/IR/AddressAnalysis.cpp index 4257ac8..a8daf58 100644 --- a/src/PIM/Common/IR/AddressAnalysis.cpp +++ b/src/PIM/Common/IR/AddressAnalysis.cpp @@ -264,7 +264,7 @@ llvm::FailureOr resolveContiguousAddressImpl(mlir::Va return mlir::failure(); auto sourceStrides = computeRowMajorStrides(sourceType.getShape()); - byteOffset += linearizeIndex(offsets, sourceStrides) * subviewType.getElementTypeBitWidth() / 8; + byteOffset += linearizeIndex(offsets, sourceStrides) * getElementTypeSizeInBytes(subviewType.getElementType()); value = resolveAlias(subviewOp.getSource(), knowledge); continue; } diff --git a/src/PIM/Common/IR/ShapeUtils.cpp b/src/PIM/Common/IR/ShapeUtils.cpp index 33253cb..0e48410 100644 --- a/src/PIM/Common/IR/ShapeUtils.cpp +++ b/src/PIM/Common/IR/ShapeUtils.cpp @@ -1,4 +1,5 @@ #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/ErrorHandling.h" #include "src/Accelerators/PIM/Common/IR/ShapeUtils.hpp" @@ -35,6 +36,30 @@ int64_t getNumElements(llvm::ArrayRef shape) { return numElements; } +bool hasByteSizedElementType(mlir::Type elementType) { + if (mlir::isa(elementType)) + return true; + if (auto intType = mlir::dyn_cast(elementType)) + return intType.getWidth() > 0 && intType.getWidth() % 8 == 0; + if (auto floatType = mlir::dyn_cast(elementType)) + return floatType.getWidth() > 0 && floatType.getWidth() % 8 == 0; + return false; +} + +size_t getElementTypeSizeInBytes(mlir::Type elementType) { + if (mlir::isa(elementType)) + return mlir::IndexType::kInternalStorageBitWidth / 8; + if (auto intType = mlir::dyn_cast(elementType)) + return static_cast(intType.getWidth() / 8); + if (auto floatType = mlir::dyn_cast(elementType)) + return static_cast(floatType.getWidth() / 8); + llvm_unreachable("expected byte-sized integer, float, or index element type"); +} + +size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType) { + return static_cast(shapedType.getNumElements()) * getElementTypeSizeInBytes(shapedType.getElementType()); +} + bool isMemoryContiguous(llvm::ArrayRef srcShape, llvm::ArrayRef offsets, llvm::ArrayRef sizes, diff --git a/src/PIM/Common/IR/ShapeUtils.hpp b/src/PIM/Common/IR/ShapeUtils.hpp index 41d666a..214846f 100644 --- a/src/PIM/Common/IR/ShapeUtils.hpp +++ b/src/PIM/Common/IR/ShapeUtils.hpp @@ -1,8 +1,13 @@ #pragma once +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Value.h" + #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" +#include + namespace onnx_mlir { llvm::SmallVector computeRowMajorStrides(llvm::ArrayRef shape); @@ -14,6 +19,12 @@ int64_t linearizeIndex(llvm::ArrayRef indices, llvm::ArrayRef int64_t getNumElements(llvm::ArrayRef shape); +bool hasByteSizedElementType(mlir::Type elementType); + +size_t getElementTypeSizeInBytes(mlir::Type elementType); + +size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType); + bool isMemoryContiguous(llvm::ArrayRef srcShape, llvm::ArrayRef offsets, llvm::ArrayRef sizes, diff --git a/src/PIM/Compiler/PimCodeGen.cpp b/src/PIM/Compiler/PimCodeGen.cpp index 64e12df..2e26140 100644 --- a/src/PIM/Compiler/PimCodeGen.cpp +++ b/src/PIM/Compiler/PimCodeGen.cpp @@ -41,23 +41,10 @@ using namespace mlir; using namespace onnx_mlir; using namespace onnx_mlir::compact_asm; -static size_t getElementTypeSizeInBytes(mlir::Type elementType) { - if (elementType.isIndex()) - return sizeof(int64_t); - if (elementType.isIntOrFloat()) - return elementType.getIntOrFloatBitWidth() / 8; - llvm_unreachable("unsupported shaped element type"); -} - -static size_t getValueSizeInBytes(mlir::Value value) { - auto type = cast(value.getType()); - return type.getNumElements() * getElementTypeSizeInBytes(type.getElementType()); -} - MemEntry* PimMemory::gatherMemEntry(mlir::Value value) { auto type = cast(value.getType()); assert("Only static shape is supported" && type.hasStaticShape()); - size_t allocSize = type.getNumElements() * getElementTypeSizeInBytes(type.getElementType()); + size_t allocSize = getShapedTypeSizeInBytes(type); MemEntry memEntry = {0, allocSize}; return &memEntries.emplace_back(memEntry, value).first; } @@ -450,7 +437,8 @@ void PimCodeGen::codeGenReceiveOp(pim::PimReceiveOp receiveOp, const StaticValue void PimCodeGen::codeGenReceiveTensorOp(pim::PimReceiveTensorOp receiveTensorOp, const StaticValueKnowledge& knowledge) const { size_t outputAddr = addressOf(receiveTensorOp.getOutputBuffer(), knowledge); - size_t chunkSize = getValueSizeInBytes(receiveTensorOp.getOutputBuffer()) / receiveTensorOp.getSourceCoreIds().size(); + size_t chunkSize = getShapedTypeSizeInBytes(cast(receiveTensorOp.getOutputBuffer().getType())) + / receiveTensorOp.getSourceCoreIds().size(); for (auto [chunkIndex, sourceCoreId] : llvm::enumerate(receiveTensorOp.getSourceCoreIds())) emitCommunicationOp("recv", outputAddr + chunkIndex * chunkSize, sourceCoreId, chunkSize); } @@ -463,7 +451,8 @@ void PimCodeGen::codeGenSendOp(pim::PimSendOp sendOp, const StaticValueKnowledge void PimCodeGen::codeGenSendTensorOp(pim::PimSendTensorOp sendTensorOp, const StaticValueKnowledge& knowledge) const { size_t inputAddr = addressOf(sendTensorOp.getInput(), knowledge); - size_t chunkSize = getValueSizeInBytes(sendTensorOp.getInput()) / sendTensorOp.getTargetCoreIds().size(); + size_t chunkSize = getShapedTypeSizeInBytes(cast(sendTensorOp.getInput().getType())) + / sendTensorOp.getTargetCoreIds().size(); for (auto [chunkIndex, targetCoreId] : llvm::enumerate(sendTensorOp.getTargetCoreIds())) emitCommunicationOp("send", inputAddr + chunkIndex * chunkSize, targetCoreId, chunkSize); } @@ -474,7 +463,7 @@ void PimCodeGen::codeGenConcatOp(pim::PimConcatOp concatOp, const StaticValueKno int64_t axis = concatOp.getAxis(); ArrayRef outputShape = outputType.getShape(); - size_t elementSize = outputType.getElementTypeBitWidth() / 8; + size_t elementSize = getElementTypeSizeInBytes(outputType.getElementType()); size_t outputAddr = addressOf(concatOp.getOutputBuffer(), knowledge); size_t outerCount = 1; @@ -526,7 +515,7 @@ void PimCodeGen::codeGenVVAddOp(pim::PimVVAddOp vvaddOp, const StaticValueKnowle instruction.rd = 0; instruction.r1 = 1; instruction.r2OrImm = 2; - instruction.generic3 = static_cast(getValueSizeInBytes(vvaddOp.getLhs())); + instruction.generic3 = static_cast(getShapedTypeSizeInBytes(cast(vvaddOp.getLhs().getType()))); emitInstruction(instruction); } @@ -541,7 +530,7 @@ void PimCodeGen::codeGenVVSubOp(pim::PimVVSubOp vvsubOp, const StaticValueKnowle instruction.rd = 0; instruction.r1 = 1; instruction.r2OrImm = 2; - instruction.generic3 = static_cast(getValueSizeInBytes(vvsubOp.getLhs())); + instruction.generic3 = static_cast(getShapedTypeSizeInBytes(cast(vvsubOp.getLhs().getType()))); emitInstruction(instruction); } @@ -556,7 +545,7 @@ void PimCodeGen::codeGenVVMulOp(pim::PimVVMulOp vvmulOp, const StaticValueKnowle instruction.rd = 0; instruction.r1 = 1; instruction.r2OrImm = 2; - instruction.generic3 = static_cast(getValueSizeInBytes(vvmulOp.getLhs())); + instruction.generic3 = static_cast(getShapedTypeSizeInBytes(cast(vvmulOp.getLhs().getType()))); emitInstruction(instruction); } @@ -571,7 +560,7 @@ void PimCodeGen::codeGenVVMaxOp(pim::PimVVMaxOp vvmaxOp, const StaticValueKnowle instruction.rd = 0; instruction.r1 = 1; instruction.r2OrImm = 2; - instruction.generic3 = static_cast(getValueSizeInBytes(vvmaxOp.getLhs())); + instruction.generic3 = static_cast(getShapedTypeSizeInBytes(cast(vvmaxOp.getLhs().getType()))); emitInstruction(instruction); } @@ -586,7 +575,7 @@ void PimCodeGen::codeGenVVDMulOp(pim::PimVVDMulOp vvdmulOp, const StaticValueKno instruction.rd = 0; instruction.r1 = 1; instruction.r2OrImm = 2; - instruction.generic3 = static_cast(getValueSizeInBytes(vvdmulOp.getLhs())); + instruction.generic3 = static_cast(getShapedTypeSizeInBytes(cast(vvdmulOp.getLhs().getType()))); emitInstruction(instruction); } @@ -601,7 +590,7 @@ void PimCodeGen::codeGenVAvgOp(pim::PimVAvgOp vavgOp, const StaticValueKnowledge instruction.r1 = 1; instruction.r2OrImm = 1; instruction.generic1 = 1; - instruction.generic3 = static_cast(getValueSizeInBytes(vavgOp.getInput())); + instruction.generic3 = static_cast(getShapedTypeSizeInBytes(cast(vavgOp.getInput().getType()))); emitInstruction(instruction); } @@ -614,7 +603,7 @@ void PimCodeGen::codeGenVReluOp(pim::PimVReluOp vreluOp, const StaticValueKnowle instruction.opcode = pim_binary::Opcode::vrelu; instruction.rd = 0; instruction.r1 = 1; - instruction.generic3 = static_cast(getValueSizeInBytes(vreluOp.getInput())); + instruction.generic3 = static_cast(getShapedTypeSizeInBytes(cast(vreluOp.getInput().getType()))); emitInstruction(instruction); } @@ -627,7 +616,7 @@ void PimCodeGen::codeGenVTanhOp(pim::PimVTanhOp vtanhOp, const StaticValueKnowle instruction.opcode = pim_binary::Opcode::vtanh; instruction.rd = 0; instruction.r1 = 1; - instruction.generic3 = static_cast(getValueSizeInBytes(vtanhOp.getInput())); + instruction.generic3 = static_cast(getShapedTypeSizeInBytes(cast(vtanhOp.getInput().getType()))); emitInstruction(instruction); } @@ -640,7 +629,7 @@ void PimCodeGen::codeGenVSigmOp(pim::PimVSigmOp vsigmOp, const StaticValueKnowle instruction.opcode = pim_binary::Opcode::vsigm; instruction.rd = 0; instruction.r1 = 1; - instruction.generic3 = static_cast(getValueSizeInBytes(vsigmOp.getInput())); + instruction.generic3 = static_cast(getShapedTypeSizeInBytes(cast(vsigmOp.getInput().getType()))); emitInstruction(instruction); } @@ -653,7 +642,8 @@ void PimCodeGen::codeGenVSoftmaxOp(pim::PimVSoftmaxOp vsoftmaxOp, const StaticVa instruction.opcode = pim_binary::Opcode::vsoftmax; instruction.rd = 0; instruction.r1 = 1; - instruction.generic3 = static_cast(getValueSizeInBytes(vsoftmaxOp.getInput())); + instruction.generic3 = + static_cast(getShapedTypeSizeInBytes(cast(vsoftmaxOp.getInput().getType()))); emitInstruction(instruction); } @@ -666,7 +656,7 @@ void PimCodeGen::codeGenTransposeOp(pim::PimTransposeOp transposeOp, const Stati auto srcType = cast(transposeOp.getInput().getType()); auto srcShape = srcType.getShape(); size_t rank = srcShape.size(); - size_t elementSize = srcType.getElementTypeBitWidth() / 8; + size_t elementSize = getElementTypeSizeInBytes(srcType.getElementType()); size_t totalElements = srcType.getNumElements(); // Read permutation. Destination dim i corresponds to source dim perm[i]. diff --git a/src/PIM/Compiler/PimWeightEmitter.cpp b/src/PIM/Compiler/PimWeightEmitter.cpp index 94ea43a..971fa47 100644 --- a/src/PIM/Compiler/PimWeightEmitter.cpp +++ b/src/PIM/Compiler/PimWeightEmitter.cpp @@ -208,7 +208,7 @@ createAndPopulateWeightFolder(func::FuncOp funcOp, StringRef outputDirPath) { int64_t numCols = shape[1]; assert(numRows <= xbarSize && numCols <= xbarSize && "Weight dimensions must not exceed crossbar size"); - size_t elementByteWidth = denseAttr.getElementType().getIntOrFloatBitWidth() / 8; + size_t elementByteWidth = getElementTypeSizeInBytes(denseAttr.getElementType()); std::string newFileName = "crossbar_" + std::to_string(indexFileName++) + ".bin"; auto weightFilePath = (coreWeightsDirPath + "/" + newFileName).str(); diff --git a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp index 3bd61c7..20511b3 100644 --- a/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp +++ b/src/PIM/Conversion/ONNXToSpatial/Patterns/Math/Gemm.cpp @@ -427,11 +427,7 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp, auto inputArg = computeOp.getInputArgument(aHSliceId); if (!weightArg || !inputArg) return failure(); - vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter, - gemmLoc, - currOutHSliceType, - *weightArg, - *inputArg)); + vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter, gemmLoc, currOutHSliceType, *weightArg, *inputArg)); } if (vmmOutputs.empty()) { gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs"); diff --git a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp index 6e93896..f863957 100644 --- a/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp +++ b/src/PIM/Conversion/SpatialToPim/BatchCoreLoweringPatterns.cpp @@ -121,7 +121,7 @@ static Value createHostTargetOffset(IRRewriter& rewriter, tensor::ParallelInsertSliceOp insertSlice, ShapedType destinationType, IRMapping& mapper) { - int64_t elementBytes = destinationType.getElementTypeBitWidth() / 8; + int64_t elementBytes = static_cast(getElementTypeSizeInBytes(destinationType.getElementType())); SmallVector strides(destinationType.getRank(), 1); ArrayRef shape = destinationType.getShape(); for (int64_t dim = destinationType.getRank() - 2; dim >= 0; --dim) diff --git a/src/PIM/Conversion/SpatialToPim/Common.cpp b/src/PIM/Conversion/SpatialToPim/Common.cpp index 20891fb..eea9698 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.cpp +++ b/src/PIM/Conversion/SpatialToPim/Common.cpp @@ -55,10 +55,6 @@ size_t getSliceActualOffset(tensor::ExtractSliceOp& sliceOp, ShapedType& inputSh return returnValue; } -size_t getShapedTypeSizeInBytes(ShapedType shapedType) { - return shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8; -} - IntegerAttr getTensorSizeInBytesAttr(Builder& builder, mlir::Value value) { return builder.getI32IntegerAttr(static_cast(getShapedTypeSizeInBytes(cast(value.getType())))); } diff --git a/src/PIM/Conversion/SpatialToPim/Common.hpp b/src/PIM/Conversion/SpatialToPim/Common.hpp index f0a09e9..7377000 100644 --- a/src/PIM/Conversion/SpatialToPim/Common.hpp +++ b/src/PIM/Conversion/SpatialToPim/Common.hpp @@ -20,8 +20,6 @@ namespace onnx_mlir { */ size_t getSliceActualOffset(mlir::tensor::ExtractSliceOp& sliceOp, mlir::ShapedType& inputShape); -size_t getShapedTypeSizeInBytes(mlir::ShapedType shapedType); - mlir::IntegerAttr getTensorSizeInBytesAttr(mlir::Builder& builder, mlir::Value value); template diff --git a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp index 859e4af..b258253 100644 --- a/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp +++ b/src/PIM/Conversion/SpatialToPim/ReturnPathNormalization.cpp @@ -433,7 +433,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low markOpToRemove(op); auto storedType = cast(currentStoredValue.getType()); - size_t elementSize = storedType.getElementTypeBitWidth() / 8; + size_t elementSize = getElementTypeSizeInBytes(storedType.getElementType()); if (auto storedOp = currentStoredValue.getDefiningOp()) rewriter.setInsertionPointAfter(storedOp); Value outputTensor = outputTensors[returnUse->returnIndex](rewriter, loc); @@ -455,7 +455,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low if (isa(resultUser)) { size_t resultIndexInReturn = resultUse.getOperandNumber(); - size_t elementSize = storedTensorType.getElementType().getIntOrFloatBitWidth() / 8; + size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType()); rewriter.setInsertionPointAfterValue(storedValue); Value outputTensor = outputTensors[resultIndexInReturn](rewriter, loc); emitHostCopy(rewriter, @@ -471,7 +471,7 @@ raptor::SpatialToPimPass::ReturnPathLoweringResult raptor::SpatialToPimPass::low } if (auto concatReturnUse = analyzeConcatReturnUse(producedValue)) { - size_t elementSize = storedTensorType.getElementTypeBitWidth() / 8; + size_t elementSize = getElementTypeSizeInBytes(storedTensorType.getElementType()); for (Operation* concatOp : concatReturnUse->concatChain) markOpToRemove(concatOp); diff --git a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp index de992da..5a5a1d1 100644 --- a/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp +++ b/src/PIM/Conversion/SpatialToPim/SpatialToPimPass.cpp @@ -325,9 +325,9 @@ LogicalResult raptor::SpatialToPimPass::allocateAndInitializeCoreLocalVariables( auto insertMemCopyHostToDev = [&](Value inputTensor, int64_t elementsOffset) { auto tensorType = cast(inputTensor.getType()); Type elementType = tensorType.getElementType(); - if (!elementType.isIntOrFloat()) + if (!hasByteSizedElementType(elementType)) return; - size_t elementByteSize = elementType.getIntOrFloatBitWidth() / 8; + size_t elementByteSize = getElementTypeSizeInBytes(elementType); rewriter.setInsertionPointAfter(inputTensor.getDefiningOp()); auto deviceTensor = tensor::EmptyOp::create(rewriter, loc, tensorType.getShape(), elementType); diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp index d0c90f7..f7fcac9 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/BufferizationUtils.cpp @@ -17,7 +17,7 @@ Value materializeContiguousMemRef(Value memrefValue, Location loc, RewriterBase& auto shapedType = cast(memrefValue.getType()); auto contiguousType = MemRefType::get(shapedType.getShape(), shapedType.getElementType()); Value contiguousBuffer = memref::AllocOp::create(rewriter, loc, contiguousType); - auto sizeInBytes = shapedType.getNumElements() * shapedType.getElementTypeBitWidth() / 8; + auto sizeInBytes = getShapedTypeSizeInBytes(shapedType); return PimMemCopyOp::create(rewriter, loc, diff --git a/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.cpp b/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.cpp index 00ab90f..395806d 100644 --- a/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.cpp +++ b/src/PIM/Dialect/Pim/Transforms/Bufferization/Common.cpp @@ -1,9 +1,10 @@ #include "Dialect/Pim/Transforms/Bufferization/Common.hpp" +#include "src/Accelerators/PIM/Common/PimCommon.hpp" using namespace mlir; IntegerAttr onnx_mlir::pim::getMemRefSizeInBytesAttr(OpBuilder& builder, Value memref) { auto type = mlir::cast(memref.getType()); - int32_t sizeInBytes = static_cast(type.getNumElements() * type.getElementTypeBitWidth() / 8); + int32_t sizeInBytes = static_cast(getShapedTypeSizeInBytes(type)); return builder.getI32IntegerAttr(sizeInBytes); } diff --git a/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.cpp b/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.cpp index 45d7a70..1e49538 100644 --- a/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.cpp +++ b/src/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.cpp @@ -9,6 +9,7 @@ #include +#include "src/Accelerators/PIM/Common/PimCommon.hpp" #include "src/Accelerators/PIM/Dialect/Pim/Transforms/StaticMemoryCoalescing/StaticMemoryCoalescing.hpp" using namespace mlir; @@ -23,11 +24,12 @@ static bool isSupportedAliasOp(Operation* op) { } static bool isCandidateAllocType(MemRefType type) { - return type && type.hasStaticShape() && type.getLayout().isIdentity() && type.getElementTypeBitWidth() > 0; + return type && type.hasStaticShape() && type.getLayout().isIdentity() + && hasByteSizedElementType(type.getElementType()); } static uint64_t getTypeSizeBytes(MemRefType type) { - return static_cast(type.getNumElements() * type.getElementTypeBitWidth() / 8); + return static_cast(type.getNumElements() * getElementTypeSizeInBytes(type.getElementType())); } static FailureOr diff --git a/src/PIM/Dialect/Spatial/SpatialOps.cpp b/src/PIM/Dialect/Spatial/SpatialOps.cpp index 4c650f3..0ced6d2 100644 --- a/src/PIM/Dialect/Spatial/SpatialOps.cpp +++ b/src/PIM/Dialect/Spatial/SpatialOps.cpp @@ -34,7 +34,9 @@ void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t i } // namespace -std::optional SpatCompute::getWeightArgument(unsigned idx) { return getBatchBodyArgument(getBody(), idx); } +std::optional SpatCompute::getWeightArgument(unsigned idx) { + return getBatchBodyArgument(getBody(), idx); +} std::optional SpatCompute::getInputArgument(unsigned idx) { return getBatchBodyArgument(getBody(), getWeights().size() + idx); @@ -74,11 +76,13 @@ SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Locat resultTypes.insert(resultTypes.begin() + idx, type); auto newCompute = SpatCompute::create(rewriter, getLoc(), TypeRange(resultTypes), getWeights(), getInputs()); newCompute->setAttrs((*this)->getAttrs()); - setComputeOperandSegmentSizes( - newCompute.getOperation(), static_cast(newCompute.getWeights().size()), static_cast(newCompute.getInputs().size())); + setComputeOperandSegmentSizes(newCompute.getOperation(), + static_cast(newCompute.getWeights().size()), + static_cast(newCompute.getInputs().size())); rewriter.inlineRegionBefore(getBody(), newCompute.getBody(), newCompute.getBody().end()); for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx) - getResult(oldResultIdx).replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1)); + getResult(oldResultIdx) + .replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1)); rewriter.eraseOp(getOperation()); return std::make_tuple(cast(newCompute.getResult(idx)), newCompute); } @@ -110,7 +114,8 @@ std::optional SpatComputeBatch::getOutputArgument(unsigned idx) { return getBatchBodyArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx); } -std::optional> SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) { +std::optional> +SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) { unsigned weightCount = getWeights().size(); unsigned inputCount = getInputs().size(); getOperation()->insertOperands(idx, ValueRange {weight}); @@ -145,8 +150,9 @@ SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, auto newBatch = SpatComputeBatch::create(rewriter, getLoc(), TypeRange(resultTypes), getLaneCountAttr(), getWeights(), getInputs()); newBatch->setAttrs((*this)->getAttrs()); - setComputeOperandSegmentSizes( - newBatch.getOperation(), static_cast(newBatch.getWeights().size()), static_cast(newBatch.getInputs().size())); + setComputeOperandSegmentSizes(newBatch.getOperation(), + static_cast(newBatch.getWeights().size()), + static_cast(newBatch.getInputs().size())); rewriter.inlineRegionBefore(getBody(), newBatch.getBody(), newBatch.getBody().end()); if (newBatch.getBody().empty()) { rewriter.eraseOp(newBatch); @@ -155,7 +161,8 @@ SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, auto blockArg = newBatch.getBody().front().insertArgument( 1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc); for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx) - getResult(oldResultIdx).replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1)); + getResult(oldResultIdx) + .replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1)); rewriter.eraseOp(getOperation()); return std::make_tuple(cast(newBatch.getResult(idx)), blockArg, newBatch); } diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp index bdc8ca1..272c9df 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MaterializeMergeSchedule.cpp @@ -4,7 +4,6 @@ #include "mlir/IR/IRMapping.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Transforms/FoldUtils.h" #include "llvm/ADT/DenseMap.h" @@ -131,7 +130,7 @@ struct MaterializerState { DenseMap cpuSlotToInstance; DenseSet materializedSlots; - DenseMap, ProducerKeyInfo> producerDestClasses; + DenseMap, ProducerKeyInfo> producerDestClasses; DenseMap, ProducerKeyInfo> availableValues; DenseMap hostReplacements; DenseSet oldComputeOps; @@ -574,6 +573,77 @@ SmallVector createIndexConstants(MaterializerState& state, Operation* return createIndexConstants(state, anchor, ArrayRef(widened)); } +Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef values) { + SmallVector elements; + elements.reserve(values.size()); + for (int64_t value : values) + elements.push_back(APInt(64, value)); + + auto type = RankedTensorType::get({static_cast(values.size())}, state.rewriter.getIndexType()); + auto attr = DenseIntElementsAttr::get(type, elements); + return getOrCreateHostConstant(anchor, attr, type, state.constantFolder); +} + +Value createIndexTensorConstant(MaterializerState& state, Operation* anchor, ArrayRef values) { + SmallVector widened; + widened.reserve(values.size()); + for (int32_t value : values) + widened.push_back(value); + return createIndexTensorConstant(state, anchor, ArrayRef(widened)); +} + +bool allEqual(ArrayRef values) { + assert(!values.empty() && "expected at least one value"); + for (int64_t value : values.drop_front()) + if (value != values.front()) + return false; + return true; +} + +bool allEqual(ArrayRef values) { + assert(!values.empty() && "expected at least one value"); + for (int32_t value : values.drop_front()) + if (value != values.front()) + return false; + return true; +} + +Value createLaneIndexedIndexValue(MaterializerState& state, + MaterializedClass& materializedClass, + ArrayRef values, + Location loc) { + assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class"); + assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane"); + + if (allEqual(values)) + return createIndexConstant(state, materializedClass.op, values.front()); + + auto batch = cast(materializedClass.op); + auto laneArg = batch.getLaneArgument(); + assert(laneArg && "expected compute_batch lane argument"); + + Value table = createIndexTensorConstant(state, materializedClass.op, values); + return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult(); +} + +Value createLaneIndexedIndexValue(MaterializerState& state, + MaterializedClass& materializedClass, + ArrayRef values, + Location loc) { + assert(materializedClass.isBatch && "lane-indexed value requires a materialized batch class"); + assert(values.size() == materializedClass.cpus.size() && "expected one value per materialized batch lane"); + + if (allEqual(values)) + return createIndexConstant(state, materializedClass.op, values.front()); + + auto batch = cast(materializedClass.op); + auto laneArg = batch.getLaneArgument(); + assert(laneArg && "expected compute_batch lane argument"); + + Value table = createIndexTensorConstant(state, materializedClass.op, values); + return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult(); +} + FailureOr> getPeerInstances(MaterializerState& state, const MaterializedClass& materializedClass, SlotId slot) { SmallVector peers; @@ -623,14 +693,12 @@ Value createOriginalLaneValue(MaterializerState& state, return arith::AddIOp::create(state.rewriter, loc, *laneArg, baseValue).getResult(); } - SmallVector laneValues; + SmallVector laneValues; laneValues.reserve(peers.size()); for (const ComputeInstance& peer : peers) - laneValues.push_back(APInt(64, peer.laneStart)); + laneValues.push_back(peer.laneStart); - auto tableType = RankedTensorType::get({static_cast(peers.size())}, state.rewriter.getIndexType()); - auto tableAttr = DenseIntElementsAttr::get(tableType, laneValues); - Value table = arith::ConstantOp::create(state.rewriter, loc, tableType, tableAttr).getResult(); + Value table = createIndexTensorConstant(state, materializedClass.op, laneValues); return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult(); } @@ -659,6 +727,12 @@ bool hasLiveExternalUse(Value value, const DenseSet& oldComputeOps) return false; } +void appendDestinationClass(MaterializerState& state, ProducerKey key, ClassId classId) { + SmallVector& destinations = state.producerDestClasses[key]; + if (!llvm::is_contained(destinations, classId)) + destinations.push_back(classId); +} + void replaceLiveExternalUses(Value oldValue, Value replacement, const DenseSet& oldComputeOps) { SmallVector uses; for (OpOperand& use : oldValue.getUses()) @@ -693,7 +767,7 @@ LogicalResult collectProducerDestinations(MaterializerState& state) { if (sourceClass == targetClass) continue; - state.producerDestClasses[producerKey].insert(targetClass); + appendDestinationClass(state, producerKey, targetClass); } } } @@ -714,29 +788,70 @@ bool haveSameDestinationClasses(MaterializerState& state, ArrayRef return true; auto firstIt = state.producerDestClasses.find(keys.front()); - DenseSet empty; - const DenseSet& first = firstIt == state.producerDestClasses.end() ? empty : firstIt->second; + ArrayRef first = firstIt == state.producerDestClasses.end() ? ArrayRef() : firstIt->second; for (ProducerKey key : keys.drop_front()) { auto it = state.producerDestClasses.find(key); - const DenseSet& current = it == state.producerDestClasses.end() ? empty : it->second; + ArrayRef current = it == state.producerDestClasses.end() ? ArrayRef() : it->second; if (first.size() != current.size()) return false; - for (ClassId classId : first) - if (!current.contains(classId)) + for (auto [lhs, rhs] : llvm::zip(first, current)) + if (lhs != rhs) return false; } return true; } -SmallVector getSortedDestinationClasses(MaterializerState& state, ProducerKey key) { - SmallVector destinations; +ArrayRef getDestinationClasses(MaterializerState& state, ProducerKey key) { auto it = state.producerDestClasses.find(key); if (it == state.producerDestClasses.end()) - return destinations; - for (ClassId classId : it->second) - destinations.push_back(classId); - llvm::sort(destinations); - return destinations; + return {}; + return it->second; +} + +void appendSend(MaterializerState& state, + MaterializedClass& sourceClass, + Value payload, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + Location loc) { + assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch"); + assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); + assert(!channelIds.empty() && "expected at least one send"); + + state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); + + if (sourceClass.isBatch) { + Value channelId = createLaneIndexedIndexValue(state, sourceClass, channelIds, loc); + Value sourceCoreId = createLaneIndexedIndexValue(state, sourceClass, sourceCoreIds, loc); + Value targetCoreId = createLaneIndexedIndexValue(state, sourceClass, targetCoreIds, loc); + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); + return; + } + + for (auto index : llvm::seq(0, channelIds.size())) { + Value channelId = createIndexConstant(state, sourceClass.op, channelIds[index]); + Value sourceCoreId = createIndexConstant(state, sourceClass.op, sourceCoreIds[index]); + Value targetCoreId = createIndexConstant(state, sourceClass.op, targetCoreIds[index]); + SpatChannelSendOp::create(state.rewriter, loc, channelId, sourceCoreId, targetCoreId, payload); + } +} + +Value appendScalarReceive(MaterializerState& state, + MaterializedClass& targetClass, + Type type, + int64_t channelId, + int32_t sourceCoreId, + int32_t targetCoreId, + Location loc) { + assert(!targetClass.isBatch && "scalar receive helper expects a scalar target class"); + + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + Value channelIdValue = createIndexConstant(state, targetClass.op, channelId); + Value sourceCoreIdValue = createIndexConstant(state, targetClass.op, sourceCoreId); + Value targetCoreIdValue = createIndexConstant(state, targetClass.op, targetCoreId); + return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelIdValue, sourceCoreIdValue, targetCoreIdValue) + .getOutput(); } Value appendReceive(MaterializerState& state, @@ -746,50 +861,169 @@ Value appendReceive(MaterializerState& state, ArrayRef sourceCoreIds, ArrayRef targetCoreIds, Location loc) { + assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch"); + assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); + assert(!channelIds.empty() && "expected at least one receive"); + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); - SmallVector channelIdValues = createIndexConstants(state, targetClass.op, channelIds); - SmallVector sourceCoreIdValues = createIndexConstants(state, targetClass.op, sourceCoreIds); - SmallVector targetCoreIdValues = createIndexConstants(state, targetClass.op, targetCoreIds); if (targetClass.isBatch) { - return SpatChannelReceiveBatchOp::create( - state.rewriter, loc, type, channelIdValues, sourceCoreIdValues, targetCoreIdValues) - .getOutput(); + Value channelId = createLaneIndexedIndexValue(state, targetClass, channelIds, loc); + Value sourceCoreId = createLaneIndexedIndexValue(state, targetClass, sourceCoreIds, loc); + Value targetCoreId = createLaneIndexedIndexValue(state, targetClass, targetCoreIds, loc); + return SpatChannelReceiveOp::create(state.rewriter, loc, type, channelId, sourceCoreId, targetCoreId).getOutput(); } - if (channelIds.size() != 1) { - return SpatChannelReceiveTensorOp::create( - state.rewriter, loc, type, channelIdValues, sourceCoreIdValues, targetCoreIdValues) - .getOutput(); - } - - return SpatChannelReceiveOp::create( - state.rewriter, loc, type, channelIdValues.front(), sourceCoreIdValues.front(), targetCoreIdValues.front()) - .getOutput(); + assert(channelIds.size() == 1 && "scalar target class can only receive one message at a time"); + return appendScalarReceive( + state, targetClass, type, channelIds.front(), sourceCoreIds.front(), targetCoreIds.front(), loc); } -Value appendHostReceive(MaterializerState& state, - MaterializedClass& sourceClass, - Type type, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, - Location loc) { - state.rewriter.setInsertionPointAfter(sourceClass.op); - SmallVector channelIdValues = createIndexConstants(state, sourceClass.op, channelIds); - SmallVector sourceCoreIdValues = createIndexConstants(state, sourceClass.op, sourceCoreIds); - SmallVector targetCoreIdValues = createIndexConstants(state, sourceClass.op, targetCoreIds); +Value appendPackedScalarReceives(MaterializerState& state, + MaterializedClass& targetClass, + Type fragmentType, + Type packedType, + ArrayRef channelIds, + ArrayRef sourceCoreIds, + ArrayRef targetCoreIds, + Location loc) { + assert(!targetClass.isBatch && "packed scalar receive helper expects a scalar target class"); + assert(channelIds.size() == sourceCoreIds.size() && "channel/source count mismatch"); + assert(channelIds.size() == targetCoreIds.size() && "channel/target count mismatch"); + assert(!channelIds.empty() && "expected at least one receive"); - if (sourceClass.isBatch) { - return SpatChannelReceiveTensorOp::create( - state.rewriter, loc, type, channelIdValues, sourceCoreIdValues, targetCoreIdValues) - .getOutput(); + SmallVector fragments; + fragments.reserve(channelIds.size()); + for (auto index : llvm::seq(0, channelIds.size())) { + fragments.push_back(appendScalarReceive( + state, targetClass, fragmentType, channelIds[index], sourceCoreIds[index], targetCoreIds[index], loc)); } - assert(channelIds.size() == 1 && "scalar host receive expects one channel"); - return SpatChannelReceiveOp::create( - state.rewriter, loc, type, channelIdValues.front(), sourceCoreIdValues.front(), targetCoreIdValues.front()) - .getOutput(); + state.rewriter.setInsertionPoint(targetClass.body->getTerminator()); + + Value packed = fragments.front(); + if (fragments.size() != 1) + packed = tensor::ConcatOp::create(state.rewriter, loc, 0, ValueRange(fragments)).getResult(); + + if (packed.getType() != packedType) + packed = tensor::CastOp::create(state.rewriter, loc, packedType, packed).getResult(); + + return packed; +} + +LogicalResult emitClassToClassCommunication(MaterializerState& state, + MaterializedClass& sourceClass, + MaterializedClass& targetClass, + ArrayRef keys, + Value payload, + Location loc) { + if (sourceClass.id == targetClass.id) { + for (ProducerKey key : keys) + state.availableValues[key][targetClass.id] = payload; + return success(); + } + + if (!sourceClass.isBatch && !targetClass.isBatch) { + int64_t channelId = state.nextChannelId++; + int32_t sourceCpu = static_cast(sourceClass.cpus.front()); + int32_t targetCpu = static_cast(targetClass.cpus.front()); + + SmallVector channelIds {channelId}; + SmallVector sourceCoreIds {sourceCpu}; + SmallVector targetCoreIds {targetCpu}; + + appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); + Value received = + appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); + + for (ProducerKey key : keys) + state.availableValues[key][targetClass.id] = received; + return success(); + } + + if (!sourceClass.isBatch && targetClass.isBatch) { + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + channelIds.reserve(targetClass.cpus.size()); + sourceCoreIds.reserve(targetClass.cpus.size()); + targetCoreIds.reserve(targetClass.cpus.size()); + + int32_t sourceCpu = static_cast(sourceClass.cpus.front()); + for (CpuId targetCpu : targetClass.cpus) { + channelIds.push_back(state.nextChannelId++); + sourceCoreIds.push_back(sourceCpu); + targetCoreIds.push_back(static_cast(targetCpu)); + } + + appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); + Value received = + appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); + + for (ProducerKey key : keys) + state.availableValues[key][targetClass.id] = received; + return success(); + } + + if (sourceClass.isBatch && !targetClass.isBatch) { + std::optional packedKey = getContiguousProducerKeyForKeys(keys); + if (!packedKey) + return sourceClass.op->emitError( + "cannot materialize batch-to-scalar communication because source lanes are not contiguous"); + + FailureOr packedType = getPackedBatchTensorType(payload.getType(), keys.size()); + if (failed(packedType)) + return sourceClass.op->emitError( + "cannot materialize batch-to-scalar communication for non-static ranked tensor payload"); + + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + channelIds.reserve(sourceClass.cpus.size()); + sourceCoreIds.reserve(sourceClass.cpus.size()); + targetCoreIds.reserve(sourceClass.cpus.size()); + + int32_t targetCpu = static_cast(targetClass.cpus.front()); + for (CpuId sourceCpu : sourceClass.cpus) { + channelIds.push_back(state.nextChannelId++); + sourceCoreIds.push_back(static_cast(sourceCpu)); + targetCoreIds.push_back(targetCpu); + } + + appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); + Value received = appendPackedScalarReceives( + state, targetClass, payload.getType(), *packedType, channelIds, sourceCoreIds, targetCoreIds, loc); + + state.availableValues[*packedKey][targetClass.id] = received; + return success(); + } + + if (sourceClass.isBatch && targetClass.isBatch) { + if (sourceClass.cpus.size() != targetClass.cpus.size()) + return sourceClass.op->emitError( + "cannot materialize batch communication between equivalence classes of different sizes"); + + SmallVector channelIds; + SmallVector sourceCoreIds; + SmallVector targetCoreIds; + channelIds.reserve(sourceClass.cpus.size()); + sourceCoreIds.reserve(sourceClass.cpus.size()); + targetCoreIds.reserve(targetClass.cpus.size()); + + for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { + channelIds.push_back(state.nextChannelId++); + sourceCoreIds.push_back(static_cast(sourceCpu)); + targetCoreIds.push_back(static_cast(targetClass.cpus[lane])); + } + + appendSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); + Value received = + appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); + + for (ProducerKey key : keys) + state.availableValues[key][targetClass.id] = received; + return success(); + } } LogicalResult @@ -821,207 +1055,50 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val if (!payloadType || !payloadType.hasStaticShape()) return sourceClass.op->emitError("host-facing compute_batch payload must be a static ranked tensor"); + auto laneArg = batch.getLaneArgument(); + if (!laneArg) + return batch.emitOpError("expected compute_batch lane block argument while materializing batch output"); + + auto outputArg = batch.getOutputArgument(resultIndex); + if (!outputArg) + return batch.emitOpError("expected compute_batch output block argument while materializing batch output"); + state.rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front()); + SmallVector offsets; SmallVector sizes; SmallVector strides; offsets.reserve(payloadType.getRank()); sizes.reserve(payloadType.getRank()); strides.reserve(payloadType.getRank()); - auto laneArg = batch.getLaneArgument(); - if (!laneArg) - return batch.emitOpError("expected compute_batch lane block argument while materializing batch output"); + offsets.push_back(*laneArg); sizes.push_back(state.rewriter.getIndexAttr(1)); strides.push_back(state.rewriter.getIndexAttr(1)); + for (int64_t dim = 1; dim < payloadType.getRank(); ++dim) { offsets.push_back(state.rewriter.getIndexAttr(0)); sizes.push_back(state.rewriter.getIndexAttr(payloadType.getDimSize(dim))); strides.push_back(state.rewriter.getIndexAttr(1)); } - auto outputArg = batch.getOutputArgument(resultIndex); - if (!outputArg) - return batch.emitOpError("expected compute_batch output block argument while materializing batch output"); - tensor::ParallelInsertSliceOp::create(state.rewriter, payload.getLoc(), payload, *outputArg, offsets, sizes, strides); return success(); } -void appendScalarSend(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - int64_t channelId, - int32_t sourceCoreId, - int32_t targetCoreId, - Location loc) { - state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); - Value channelIdValue = createIndexConstant(state, sourceClass.op, channelId); - Value sourceCoreIdValue = createIndexConstant(state, sourceClass.op, sourceCoreId); - Value targetCoreIdValue = createIndexConstant(state, sourceClass.op, targetCoreId); - SpatChannelSendOp::create(state.rewriter, loc, channelIdValue, sourceCoreIdValue, targetCoreIdValue, payload); -} - -void appendBatchSend(MaterializerState& state, - MaterializedClass& sourceClass, - Value payload, - ArrayRef channelIds, - ArrayRef sourceCoreIds, - ArrayRef targetCoreIds, - Location loc) { - state.rewriter.setInsertionPoint(sourceClass.body->getTerminator()); - SmallVector channelIdValues = createIndexConstants(state, sourceClass.op, channelIds); - SmallVector sourceCoreIdValues = createIndexConstants(state, sourceClass.op, sourceCoreIds); - SmallVector targetCoreIdValues = createIndexConstants(state, sourceClass.op, targetCoreIds); - SpatChannelSendBatchOp::create(state.rewriter, loc, channelIdValues, sourceCoreIdValues, targetCoreIdValues, payload); -} - -LogicalResult emitClassToClassCommunication(MaterializerState& state, - MaterializedClass& sourceClass, - MaterializedClass& targetClass, - ArrayRef keys, - Value payload, - Location loc) { - if (sourceClass.id == targetClass.id) { - for (ProducerKey key : keys) - state.availableValues[key][targetClass.id] = payload; - return success(); - } - - if (!sourceClass.isBatch && !targetClass.isBatch) { - int64_t channelId = state.nextChannelId++; - int32_t sourceCpu = static_cast(sourceClass.cpus.front()); - int32_t targetCpu = static_cast(targetClass.cpus.front()); - appendScalarSend(state, sourceClass, payload, channelId, sourceCpu, targetCpu, loc); - Value received = appendReceive(state, - targetClass, - payload.getType(), - ArrayRef(channelId), - ArrayRef(sourceCpu), - ArrayRef(targetCpu), - loc); - for (ProducerKey key : keys) - state.availableValues[key][targetClass.id] = received; - return success(); - } - - if (!sourceClass.isBatch && targetClass.isBatch) { - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - channelIds.reserve(targetClass.cpus.size()); - sourceCoreIds.reserve(targetClass.cpus.size()); - targetCoreIds.reserve(targetClass.cpus.size()); - - for (CpuId targetCpu : targetClass.cpus) { - int64_t channelId = state.nextChannelId++; - channelIds.push_back(channelId); - sourceCoreIds.push_back(static_cast(sourceClass.cpus.front())); - targetCoreIds.push_back(static_cast(targetCpu)); - appendScalarSend(state, - sourceClass, - payload, - channelId, - static_cast(sourceClass.cpus.front()), - static_cast(targetCpu), - loc); - } - - Value received = - appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); - for (ProducerKey key : keys) - state.availableValues[key][targetClass.id] = received; - return success(); - } - - if (sourceClass.isBatch && !targetClass.isBatch) { - std::optional packedKey = getContiguousProducerKeyForKeys(keys); - if (!packedKey) - return sourceClass.op->emitError("cannot materialize batch-to-scalar communication as concat because source " - "lanes are not contiguous in send order"); - - FailureOr packedType = getPackedBatchTensorType(payload.getType(), keys.size()); - if (failed(packedType)) - return sourceClass.op->emitError( - "cannot materialize batch-to-scalar communication as concat for non-static ranked tensor payload"); - - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - channelIds.reserve(sourceClass.cpus.size()); - sourceCoreIds.reserve(sourceClass.cpus.size()); - targetCoreIds.reserve(sourceClass.cpus.size()); - - for (CpuId sourceCpu : sourceClass.cpus) { - channelIds.push_back(state.nextChannelId++); - sourceCoreIds.push_back(static_cast(sourceCpu)); - targetCoreIds.push_back(static_cast(targetClass.cpus.front())); - } - - appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); - Value received = appendReceive(state, targetClass, *packedType, channelIds, sourceCoreIds, targetCoreIds, loc); - state.availableValues[*packedKey][targetClass.id] = received; - return success(); - } - - if (sourceClass.isBatch && targetClass.isBatch) { - if (sourceClass.cpus.size() != targetClass.cpus.size()) - return sourceClass.op->emitError( - "cannot materialize batch communication between equivalence classes of different sizes"); - - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - channelIds.reserve(sourceClass.cpus.size()); - sourceCoreIds.reserve(sourceClass.cpus.size()); - targetCoreIds.reserve(targetClass.cpus.size()); - - for (auto [lane, sourceCpu] : llvm::enumerate(sourceClass.cpus)) { - channelIds.push_back(state.nextChannelId++); - sourceCoreIds.push_back(static_cast(sourceCpu)); - targetCoreIds.push_back(static_cast(targetClass.cpus[lane])); - } - - appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); - Value received = - appendReceive(state, targetClass, payload.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); - for (ProducerKey key : keys) - state.availableValues[key][targetClass.id] = received; - return success(); - } - - return sourceClass.op->emitError("unhandled materialized communication pattern"); -} - LogicalResult emitHostCommunication(MaterializerState& state, MaterializedClass& sourceClass, ArrayRef keys, Value payload, Value originalOutput, Location loc) { + (void) keys; + (void) loc; + if (!hasLiveExternalUse(originalOutput, state.oldComputeOps)) return success(); - if (!sourceClass.hostOutputs.empty()) - return setHostOutputValue(state, sourceClass, originalOutput, payload); - - SmallVector channelIds; - SmallVector sourceCoreIds; - SmallVector targetCoreIds; - channelIds.reserve(sourceClass.cpus.size()); - sourceCoreIds.reserve(sourceClass.cpus.size()); - targetCoreIds.reserve(sourceClass.cpus.size()); - for (CpuId sourceCpu : sourceClass.cpus) { - channelIds.push_back(state.nextChannelId++); - sourceCoreIds.push_back(static_cast(sourceCpu)); - targetCoreIds.push_back(0); - } - - appendBatchSend(state, sourceClass, payload, channelIds, sourceCoreIds, targetCoreIds, loc); - Value received = - appendHostReceive(state, sourceClass, originalOutput.getType(), channelIds, sourceCoreIds, targetCoreIds, loc); - state.hostReplacements[originalOutput] = received; - return success(); + return setHostOutputValue(state, sourceClass, originalOutput, payload); } LogicalResult emitOutputFanout(MaterializerState& state, @@ -1034,7 +1111,7 @@ LogicalResult emitOutputFanout(MaterializerState& state, return success(); if (!sourceClass.isBatch) { - for (ClassId destinationClass : getSortedDestinationClasses(state, keys.front())) + for (ClassId destinationClass : getDestinationClasses(state, keys.front())) if (failed( emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc))) return failure(); @@ -1048,7 +1125,7 @@ LogicalResult emitOutputFanout(MaterializerState& state, return sourceClass.op->emitError( "cannot materialize batched output whose lanes have different destination equivalence classes"); - for (ClassId destinationClass : getSortedDestinationClasses(state, keys.front())) + for (ClassId destinationClass : getDestinationClasses(state, keys.front())) if (failed(emitClassToClassCommunication(state, sourceClass, state.classes[destinationClass], keys, payload, loc))) return failure(); diff --git a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp index e9f8f41..ad7abe2 100644 --- a/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp +++ b/src/PIM/Dialect/Spatial/Transforms/MergeComputeNodes/MergeComputeNodesPass.cpp @@ -660,12 +660,12 @@ public: emitMergeIrCounts("after-materialization", func); - if (failed(runPostMergeCompactionPipeline(func, nextChannelId))) { + /*if (failed(runPostMergeCompactionPipeline(func, nextChannelId))) { signalPassFailure(); return; } - emitMergeIrCounts("after-post-merge-compaction", func); + emitMergeIrCounts("after-post-merge-compaction", func);*/ { ScopedMergePhaseTimer timer("cleanup-topological-sort-report"); diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp index d575c84..4e46707 100644 --- a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp +++ b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Constant.cpp @@ -120,7 +120,7 @@ struct FoldConstantCoreMapPattern final : OpRewritePattern { rewriter.setInsertionPoint(mapOp); auto getGlobalOp = memref::GetGlobalOp::create(rewriter, mapOp.getLoc(), initType, globalOp.getName()); - auto sizeInBytes = initType.getNumElements() * initType.getElementTypeBitWidth() / 8; + auto sizeInBytes = getShapedTypeSizeInBytes(initType); pim::PimMemCopyOp::create(rewriter, mapOp.getLoc(), initType, diff --git a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp index 922ae87..14eef8d 100644 --- a/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp +++ b/src/PIM/Pass/PimCodegen/HostConstantFolding/Patterns/Subview.cpp @@ -176,9 +176,9 @@ static LogicalResult rewriteSubviewCopyLikeOp(CopyOp copyOp, if (splitSrc && splitDst && copyShape != ArrayRef(dstSubview->sizes)) return failure(); - const int64_t elementByteWidth = sourceType.getElementTypeBitWidth() / 8; - if (elementByteWidth <= 0) + if (!hasByteSizedElementType(sourceType.getElementType())) return failure(); + const int64_t elementByteWidth = static_cast(getElementTypeSizeInBytes(sourceType.getElementType())); const int64_t totalBytes = getNumElements(copyShape) * elementByteWidth; if (size != totalBytes) diff --git a/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp b/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp index 735d7de..6b67cee 100644 --- a/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp +++ b/src/PIM/Pass/PimCodegen/MaterializeHostConstantsPass.cpp @@ -31,13 +31,6 @@ static bool isExplicitHostOperand(Operation* op, unsigned operandIndex) { return false; } -static int64_t getValueSizeInBytes(Value value) { - auto type = dyn_cast(value.getType()); - if (!type || !type.hasStaticShape()) - return -1; - return type.getNumElements() * type.getElementTypeBitWidth() / 8; -} - template static void materializeHostConstantsInCore(CoreOpTy coreOp, IRRewriter& rewriter, @@ -82,7 +75,9 @@ static void materializeHostConstantsInCore(CoreOpTy coreOp, continue; } - int64_t totalBytes = getValueSizeInBytes(originalValue); + int64_t totalBytes = -1; + if (auto type = dyn_cast(originalValue.getType()); type && type.hasStaticShape()) + totalBytes = static_cast(getShapedTypeSizeInBytes(type)); if (totalBytes < 0 || !llvm::isInt<32>(totalBytes) || !llvm::isInt<32>(resolvedAddress->byteOffset)) { op->emitOpError("host constant materialization requires 32-bit copy sizes and offsets"); hasFailure = true;