#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/raw_os_ostream.h"

// NOTE(review): the original angle-bracket contents of these two includes were
// lost in extraction; <filesystem> and <fstream> are reconstructed from the
// uses of std::filesystem::create_directories and std::fstream below.
#include <filesystem>
#include <fstream>

#include "src/Accelerators/PIM/Common/PimCommon.hpp"
#include "src/Accelerators/PIM/Dialect/Pim/PimOps.hpp"
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include "src/Dialect/ONNX/ONNXOps.hpp"

using namespace mlir;

namespace onnx_mlir {

/// Returns the directory part of the global `outputBaseName` compiler option.
/// Returns "" when no usable base name is set (empty or "-" i.e. stdout), and
/// "." when the base name has no directory component.
std::string getOutputDir() {
  if (outputBaseName.empty() || outputBaseName == "-")
    return {};
  size_t lastSlash = outputBaseName.find_last_of('/');
  if (lastSlash == std::string::npos)
    return ".";
  return outputBaseName.substr(0, lastSlash);
}

/// Creates `directory` (and any missing parents). Failure is only diagnosed in
/// assert-enabled builds; release builds silently continue.
void createDirectory(const std::string& directory) {
  std::error_code errorCode;
  std::filesystem::create_directories(directory, errorCode);
  assert(!errorCode && ("Failed to create directory: " + errorCode.message()).data());
}

/// Dumps `moduleOp` as `<outputDir>/dialects/<name>.mlir` with large elements
/// attributes elided. No-op when no output directory is configured.
void dumpModule(ModuleOp moduleOp, const std::string& name) {
  std::string outputDir = getOutputDir();
  if (outputDir.empty())
    return;
  std::string dialectsDir = outputDir + "/dialects";
  createDirectory(dialectsDir);
  std::fstream file(dialectsDir + "/" + name + ".mlir", std::ios::out);
  llvm::raw_os_ostream os(file);
  OpPrintingFlags flags;
  flags.elideLargeElementsAttrs();
  moduleOp.print(os, flags);
  os.flush();
  file.close();
}

/// Resolves the single entry function the PIM pipeline should compile.
/// Resolution order:
///   1. a unique ONNXEntryPointOp's referenced function (error if several);
///   2. a function literally named "main_graph";
///   3. the unique non-external func.func in the module.
/// Emits a diagnostic and returns failure otherwise.
FailureOr<func::FuncOp> getPimEntryFunc(ModuleOp moduleOp) {
  if (!moduleOp)
    return failure();
  SmallVector<ONNXEntryPointOp> entryPoints(moduleOp.getOps<ONNXEntryPointOp>());
  if (entryPoints.size() > 1) {
    moduleOp.emitError("PIM pipeline requires a single ONNX entry point, but found ")
        << entryPoints.size();
    return failure();
  }
  if (!entryPoints.empty()) {
    auto entryPointAttr = entryPoints.front()->getAttrOfType<SymbolRefAttr>(
        ONNXEntryPointOp::getEntryPointFuncAttrName());
    if (!entryPointAttr) {
      entryPoints.front().emitOpError("is missing the entry point function attribute");
      return failure();
    }
    auto entryFunc = moduleOp.lookupSymbol<func::FuncOp>(
        entryPointAttr.getLeafReference().getValue());
    if (!entryFunc) {
      entryPoints.front().emitOpError("references an unknown entry function ")
          << entryPointAttr.getLeafReference().getValue();
      return failure();
    }
    return entryFunc;
  }
  if (auto mainGraphFunc = moduleOp.lookupSymbol<func::FuncOp>("main_graph"))
    return mainGraphFunc;
  SmallVector<func::FuncOp> nonExternalFuncs;
  for (auto funcOp : moduleOp.getOps<func::FuncOp>())
    if (!funcOp.isExternal())
      nonExternalFuncs.push_back(funcOp);
  if (nonExternalFuncs.size() == 1)
    return nonExternalFuncs.front();
  moduleOp.emitError("could not resolve a unique PIM entry function");
  return failure();
}

/// True iff `op` carries the "weight always resident" marker attribute.
bool hasWeightAlways(Operation* op) {
  return op && op->getAttr(PimWeightAlwaysAttrName) != nullptr;
}

/// Marks `op` as having its weight always resident (unit attribute).
void markWeightAlways(Operation* op) {
  assert(op && "expected valid op");
  op->setAttr(PimWeightAlwaysAttrName, UnitAttr::get(op->getContext()));
}

namespace {

/// Returns true if any MVM/VMM op nested under `parentOp` references weight
/// operand number `weightIndex`.
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
  bool found = false;
  parentOp.walk([&](Operation* op) {
    if (auto mvmOp = dyn_cast<MVMOpTy>(op))
      found |= mvmOp.getWeightIndex() == weightIndex;
    else if (auto vmmOp = dyn_cast<VMMOpTy>(op))
      found |= vmmOp.getWeightIndex() == weightIndex;
  });
  return found;
}

/// Invokes `callback` once per distinct weight operand of `parentOp` that is
/// referenced (by index) by a nested MVM or VMM op. Indices out of range of
/// `parentOp.getWeights()` are ignored.
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
void walkMvmVmmWeightUses(ParentOpTy parentOp, function_ref<void(OpOperand&)> callback) {
  auto weights = parentOp.getWeights();
  llvm::SmallSet<unsigned, 8> visited;
  auto walkWeightIndex = [&](unsigned weightIndex) {
    // Deduplicate: several MVM/VMM ops may name the same weight slot.
    if (weightIndex < weights.size() && visited.insert(weightIndex).second)
      callback(parentOp->getOpOperand(weightIndex));
  };
  parentOp.walk([&](MVMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
  parentOp.walk([&](VMMOpTy op) { walkWeightIndex(op.getWeightIndex()); });
}

} // namespace

/// Returns true iff `use` is a weight operand of a spatial compute op that is
/// actually consumed by a nested spatial MVM/VMM op.
bool isSpatialMvmVmmWeightUse(OpOperand& use) {
  Operation* user = use.getOwner();
  unsigned operandIndex = use.getOperandNumber();
  // NOTE(review): the concrete spatial op type was lost in extraction; it must
  // be the spatial dialect op exposing getWeights() — confirm against
  // SpatialOps.hpp.
  auto computeOp = dyn_cast<spatial::ComputeOp>(user);
  if (!computeOp || operandIndex >= computeOp.getWeights().size())
    return false;
  return hasMvmVmmWeightUse<spatial::MVMOp, spatial::VMMOp>(computeOp, operandIndex);
}

/// Returns true iff every transitive use of `value` — looking through
/// tensor.extract_slice, tensor.expand/collapse_shape and onnx.Transpose —
/// ends at a spatial MVM/VMM weight operand. A value with no uses at all
/// counts as false.
bool hasOnlySpatialMvmVmmWeightUses(Value value) {
  SmallPtrSet<Value, 8> visited;
  auto walkUses = [&](Value currentValue, auto& self) -> bool {
    if (!visited.insert(currentValue).second)
      return true; // already proven (or in progress) — avoid cycles
    if (currentValue.use_empty())
      return false;
    return llvm::all_of(currentValue.getUses(), [&](OpOperand& use) {
      if (isSpatialMvmVmmWeightUse(use))
        return true;
      Operation* user = use.getOwner();
      // Only look through a view-like op when `currentValue` feeds its data
      // operand (not e.g. a size operand), and recurse into the result.
      if (auto extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user))
        return extractSliceOp.getSource() == currentValue &&
               self(extractSliceOp.getResult(), self);
      if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(user))
        return expandShapeOp.getSrc() == currentValue &&
               self(expandShapeOp.getResult(), self);
      if (auto collapseShapeOp = dyn_cast<tensor::CollapseShapeOp>(user))
        return collapseShapeOp.getSrc() == currentValue &&
               self(collapseShapeOp.getResult(), self);
      if (auto transposeOp = dyn_cast<ONNXTransposeOp>(user))
        return transposeOp.getData() == currentValue &&
               self(transposeOp.getResult(), self);
      return false;
    });
  };
  return walkUses(value, walkUses);
}

/// Walks all PIM core / core-batch ops under `root` and invokes `callback`
/// on every weight OpOperand consumed by an MVM/VMM (for cores) or held
/// directly by the batch op (for core batches).
void walkPimMvmVmmWeightUses(Operation* root, function_ref<void(OpOperand&)> callback) {
  assert(root && "expected valid root op");
  root->walk([&](pim::PimCoreOp coreOp) {
    walkMvmVmmWeightUses<pim::MVMOp, pim::VMMOp>(coreOp, callback);
  });
  root->walk([&](pim::PimCoreBatchOp coreBatchOp) {
    // For batch ops every weight operand of the batch op itself counts.
    auto weights = coreBatchOp.getWeights();
    for (auto weight : weights)
      for (OpOperand& use : weight.getUses())
        if (use.getOwner() == coreBatchOp.getOperation())
          callback(use);
  });
}

/// Resolves the memref.global referenced by a memref.get_global, or null.
memref::GlobalOp lookupGlobalForGetGlobal(ModuleOp moduleOp, memref::GetGlobalOp getGlobalOp) {
  if (!moduleOp || !getGlobalOp)
    return {};
  return moduleOp.lookupSymbol<memref::GlobalOp>(getGlobalOp.getName());
}

/// Row-major (C-order) strides for a static shape: innermost stride is 1.
SmallVector<int64_t> computeRowMajorStrides(ArrayRef<int64_t> shape) {
  SmallVector<int64_t> strides(shape.size(), 1);
  for (int64_t dim = static_cast<int64_t>(shape.size()) - 2; dim >= 0; --dim)
    strides[dim] = strides[dim + 1] * shape[dim + 1];
  return strides;
}

/// Converts a linear index into per-dimension indices given row-major strides.
SmallVector<int64_t> delinearizeIndex(int64_t linearIndex, ArrayRef<int64_t> shape,
    ArrayRef<int64_t> strides) {
  SmallVector<int64_t> indices(shape.size(), 0);
  for (auto [dim, stride] : llvm::enumerate(strides)) {
    indices[dim] = linearIndex / stride;
    linearIndex %= stride;
  }
  return indices;
}

/// Dot product of indices and strides: the linear offset of an element.
int64_t linearizeIndex(ArrayRef<int64_t> indices, ArrayRef<int64_t> strides) {
  int64_t linearIndex = 0;
  for (auto [index, stride] : llvm::zip_equal(indices, strides))
    linearIndex += index * stride;
  return linearIndex;
}

/// Product of all dimension sizes (1 for an empty shape).
int64_t getNumElements(ArrayRef<int64_t> shape) {
  int64_t numElements = 1;
  for (int64_t dim : shape)
    numElements *= dim;
  return numElements;
}

/// Returns true iff the subview described by (offsets, sizes, strides) of a
/// row-major buffer of shape `srcShape` selects a single contiguous run of
/// memory. All strides must be 1; past the innermost dimension with a nonzero
/// offset (or a size differing from the source dimension) all outer sizes must
/// be 1, and an offset slice must fit within its dimension.
bool isMemoryContiguous(ArrayRef<int64_t> srcShape, ArrayRef<int64_t> offsets,
    ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides) {
  if (std::any_of(strides.begin(), strides.end(),
          [](int64_t stride) -> bool { return stride != 1; }))
    return false;
  // Iterate innermost-first so "everything outward must have size 1" is a
  // simple forward scan from the first violating dimension.
  auto offsetsAndSizesAndShape =
      llvm::zip_equal(llvm::make_range(offsets.rbegin(), offsets.rend()),
          llvm::make_range(sizes.rbegin(), sizes.rend()),
          llvm::make_range(srcShape.rbegin(), srcShape.rend()));
  auto firstNonZeroOffset = std::find_if(offsetsAndSizesAndShape.begin(),
      offsetsAndSizesAndShape.end(), [&](auto offsetAndSizeAndShape) -> bool {
        auto [offset, _size, _dimension] = offsetAndSizeAndShape;
        return offset != 0;
      });
  if (firstNonZeroOffset != offsetsAndSizesAndShape.end()) {
    auto [offset, size, dimension] = *firstNonZeroOffset;
    if (size > dimension - offset)
      return false; // slice would run off the end of this dimension
    ++firstNonZeroOffset;
    if (std::any_of(firstNonZeroOffset, offsetsAndSizesAndShape.end(),
            [](auto offsetAndSizeAndShape) -> bool {
              auto [_offset, size, _dimension] = offsetAndSizeAndShape;
              return size != 1;
            }))
      return false;
  }
  auto sizesAndShape = llvm::zip_equal(llvm::make_range(sizes.rbegin(), sizes.rend()),
      llvm::make_range(srcShape.rbegin(), srcShape.rend()));
  auto firstDifferentSize = std::find_if(sizesAndShape.begin(), sizesAndShape.end(),
      [&](auto sizeAndShape) -> bool {
        auto [size, dimension] = sizeAndShape;
        return size != dimension;
      });
  if (firstDifferentSize != sizesAndShape.end()) {
    ++firstDifferentSize;
    if (std::any_of(firstDifferentSize, sizesAndShape.end(),
            [](auto sizeAndShape) -> bool {
              auto [size, _dimension] = sizeAndShape;
              return size != 1;
            }))
      return false;
  }
  return true;
}

/// Follows the alias chain recorded in `knowledge` (if any) to a fixed point.
static Value resolveAlias(Value value, const StaticValueKnowledge* knowledge) {
  if (!knowledge)
    return value;
  auto iter = knowledge->aliases.find(value);
  while (iter != knowledge->aliases.end()) {
    value = iter->second;
    iter = knowledge->aliases.find(value);
  }
  return value;
}

// Walks through view-like ops and DPS tied operands to find the "underlying" memref value
// behind an scf.for iter-arg. Used both when resolving a contiguous address inside a loop
// and when propagating yielded values across iterations during static unrolling.
static Value resolveLoopCarriedAliasImpl(Value value, const StaticValueKnowledge* knowledge) {
  value = resolveAlias(value, knowledge);
  if (auto blockArgument = dyn_cast<BlockArgument>(value))
    return value;
  Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return value;
  if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
    if (auto result = dyn_cast<OpResult>(value))
      if (OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(result))
        return resolveLoopCarriedAliasImpl(tiedOperand->get(), knowledge);
  }
  if (auto castOp = dyn_cast<memref::CastOp>(definingOp))
    return resolveLoopCarriedAliasImpl(castOp.getSource(), knowledge);
  if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp))
    return resolveLoopCarriedAliasImpl(collapseOp.getSrc(), knowledge);
  if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp))
    return resolveLoopCarriedAliasImpl(expandOp.getSrc(), knowledge);
  return value;
}

static FailureOr<int64_t> resolveOpFoldResult(
    OpFoldResult ofr, const StaticValueKnowledge* knowledge);

/// Statically evaluates an index-typed SSA value: known values from
/// `knowledge`, arith constants, index_cast, and addi/subi/muli/divui/remui
/// trees over statically evaluable operands. Division/remainder are performed
/// unsigned, matching arith.divui/remui semantics; divide-by-zero fails.
static FailureOr<int64_t> resolveIndexValueImpl(Value value, const StaticValueKnowledge* knowledge) {
  value = resolveAlias(value, knowledge);
  if (knowledge) {
    auto iter = knowledge->indexValues.find(value);
    if (iter != knowledge->indexValues.end())
      return iter->second;
  }
  auto constantOp = value.getDefiningOp<arith::ConstantOp>();
  if (constantOp) {
    if (auto integerAttr = dyn_cast<IntegerAttr>(constantOp.getValue()))
      return integerAttr.getInt();
  }
  Operation* definingOp = value.getDefiningOp();
  if (!definingOp)
    return failure();
  if (auto indexCastOp = dyn_cast<arith::IndexCastOp>(definingOp))
    return resolveIndexValueImpl(indexCastOp.getIn(), knowledge);
  if (auto addOp = dyn_cast<arith::AddIOp>(definingOp)) {
    auto lhs = resolveIndexValueImpl(addOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(addOp.getRhs(), knowledge);
    if (failed(lhs) || failed(rhs))
      return failure();
    return *lhs + *rhs;
  }
  if (auto subOp = dyn_cast<arith::SubIOp>(definingOp)) {
    auto lhs = resolveIndexValueImpl(subOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(subOp.getRhs(), knowledge);
    if (failed(lhs) || failed(rhs))
      return failure();
    return *lhs - *rhs;
  }
  if (auto mulOp = dyn_cast<arith::MulIOp>(definingOp)) {
    auto lhs = resolveIndexValueImpl(mulOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(mulOp.getRhs(), knowledge);
    if (failed(lhs) || failed(rhs))
      return failure();
    return *lhs * *rhs;
  }
  if (auto divOp = dyn_cast<arith::DivUIOp>(definingOp)) {
    auto lhs = resolveIndexValueImpl(divOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(divOp.getRhs(), knowledge);
    if (failed(lhs) || failed(rhs) || *rhs == 0)
      return failure();
    return static_cast<int64_t>(static_cast<uint64_t>(*lhs) / static_cast<uint64_t>(*rhs));
  }
  if (auto remOp = dyn_cast<arith::RemUIOp>(definingOp)) {
    auto lhs = resolveIndexValueImpl(remOp.getLhs(), knowledge);
    auto rhs = resolveIndexValueImpl(remOp.getRhs(), knowledge);
    if (failed(lhs) || failed(rhs) || *rhs == 0)
      return failure();
    return static_cast<int64_t>(static_cast<uint64_t>(*lhs) % static_cast<uint64_t>(*rhs));
  }
  return failure();
}

/// Statically evaluates an OpFoldResult: an IntegerAttr directly, otherwise
/// the wrapped Value through resolveIndexValueImpl.
static FailureOr<int64_t> resolveOpFoldResult(
    OpFoldResult ofr, const StaticValueKnowledge* knowledge) {
  if (auto attr = dyn_cast<Attribute>(ofr)) {
    auto integerAttr = dyn_cast<IntegerAttr>(attr);
    if (!integerAttr)
      return failure();
    return integerAttr.getInt();
  }
  return resolveIndexValueImpl(cast<Value>(ofr), knowledge);
}

/// Walks backwards from `value` to a base memref (block argument, alloc-like
/// op or memref.get_global) accumulating the static byte offset contributed by
/// contiguous memref.subview ops along the way. Fails when any hop cannot be
/// resolved statically or a subview is not contiguous.
static FailureOr<ResolvedContiguousAddress> resolveContiguousAddressImpl(
    Value value, const StaticValueKnowledge* knowledge) {
  int64_t byteOffset = 0;
  value = resolveAlias(value, knowledge);
  while (true) {
    if (isa<BlockArgument>(value))
      return ResolvedContiguousAddress {value, byteOffset};
    Operation* definingOp = value.getDefiningOp();
    if (!definingOp)
      return failure();
    if (auto dpsDefiningOp = dyn_cast<DestinationStyleOpInterface>(definingOp)) {
      OpOperand* tiedOperand = dpsDefiningOp.getTiedOpOperand(dyn_cast<OpResult>(value));
      if (!tiedOperand)
        return failure();
      value = resolveAlias(tiedOperand->get(), knowledge);
      continue;
    }
    if (auto forOp = dyn_cast<scf::ForOp>(definingOp)) {
      auto result = dyn_cast<OpResult>(value);
      if (!result)
        return failure();
      // Trace the loop carry back to its underlying memref, then if that memref is the
      // loop's own iter-arg we know the base comes from the corresponding init arg
      // (every iteration yields the same backing memory in the DPS sense).
      auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
      Value yieldedValue = resolveLoopCarriedAliasImpl(
          yieldOp.getOperand(result.getResultNumber()), knowledge);
      if (auto blockArgument = dyn_cast<BlockArgument>(yieldedValue)) {
        if (blockArgument.getOwner() == forOp.getBody() &&
            blockArgument.getArgNumber() > 0 &&
            static_cast<size_t>(blockArgument.getArgNumber() - 1) < forOp.getInitArgs().size()) {
          // Arg 0 is the induction variable; iter-arg i maps to init arg i-1.
          value = resolveAlias(forOp.getInitArgs()[blockArgument.getArgNumber() - 1], knowledge);
          continue;
        }
      }
      value = yieldedValue;
      continue;
    }
    if (auto subviewOp = dyn_cast<memref::SubViewOp>(definingOp)) {
      auto sourceType = dyn_cast<MemRefType>(subviewOp.getSource().getType());
      auto subviewType = dyn_cast<MemRefType>(subviewOp.getType());
      if (!sourceType || !subviewType || !sourceType.hasStaticShape() ||
          !subviewType.hasStaticShape())
        return failure();
      SmallVector<int64_t> offsets;
      SmallVector<int64_t> sizes;
      SmallVector<int64_t> strides;
      offsets.reserve(subviewOp.getMixedOffsets().size());
      sizes.reserve(subviewOp.getMixedSizes().size());
      strides.reserve(subviewOp.getMixedStrides().size());
      for (OpFoldResult offset : subviewOp.getMixedOffsets()) {
        auto resolvedOffset = resolveOpFoldResult(offset, knowledge);
        if (failed(resolvedOffset))
          return failure();
        offsets.push_back(*resolvedOffset);
      }
      for (OpFoldResult size : subviewOp.getMixedSizes()) {
        auto resolvedSize = resolveOpFoldResult(size, knowledge);
        if (failed(resolvedSize))
          return failure();
        sizes.push_back(*resolvedSize);
      }
      for (OpFoldResult stride : subviewOp.getMixedStrides()) {
        auto resolvedStride = resolveOpFoldResult(stride, knowledge);
        if (failed(resolvedStride))
          return failure();
        strides.push_back(*resolvedStride);
      }
      if (!isMemoryContiguous(sourceType.getShape(), offsets, sizes, strides))
        return failure();
      auto sourceStrides = computeRowMajorStrides(sourceType.getShape());
      // Assumes element bit width is a multiple of 8 — sub-byte types would
      // truncate here.
      byteOffset += linearizeIndex(offsets, sourceStrides) *
                    subviewType.getElementTypeBitWidth() / 8;
      value = resolveAlias(subviewOp.getSource(), knowledge);
      continue;
    }
    if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
      value = resolveAlias(castOp.getSource(), knowledge);
      continue;
    }
    if (auto collapseOp = dyn_cast<memref::CollapseShapeOp>(definingOp)) {
      value = resolveAlias(collapseOp.getSrc(), knowledge);
      continue;
    }
    if (auto expandOp = dyn_cast<memref::ExpandShapeOp>(definingOp)) {
      value = resolveAlias(expandOp.getSrc(), knowledge);
      continue;
    }
    // NOTE(review): the terminal op list was lost in extraction; alloc-like
    // ops and get_global are the plausible bases — confirm against the
    // original op set.
    if (isa<memref::AllocOp, memref::AllocaOp, memref::GetGlobalOp>(definingOp))
      return ResolvedContiguousAddress {value, byteOffset};
    return failure();
  }
}

FailureOr<int64_t> resolveIndexValue(Value value) {
  return resolveIndexValueImpl(value, nullptr);
}

FailureOr<int64_t> resolveIndexValue(Value value, const StaticValueKnowledge& knowledge) {
  return resolveIndexValueImpl(value, &knowledge);
}

FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(Value value) {
  return resolveContiguousAddressImpl(value, nullptr);
}

FailureOr<ResolvedContiguousAddress> resolveContiguousAddress(
    Value value, const StaticValueKnowledge& knowledge) {
  return resolveContiguousAddressImpl(value, &knowledge);
}

Value resolveLoopCarriedAlias(Value value, const StaticValueKnowledge& knowledge) {
  return resolveLoopCarriedAliasImpl(value, &knowledge);
}

/// True iff `op` is a PIM op whose address is resolved statically (skipped
/// when interpreting core blocks).
// NOTE(review): the isa<...> target list was lost in extraction — restore the
// concrete pim:: op(s) from PimOps.hpp.
bool isCoreStaticAddressOp(Operation* op) {
  return isa<pim::CoreStaticAddressOp>(op);
}

/// Statically interprets `block`: unrolls scf.for loops with statically
/// evaluable bounds (propagating induction values and loop-carried memref
/// aliases through `knowledge`) and invokes `callback` on every other
/// non-skipped op. All failures are diagnosed but the walk continues;
/// returns failure if any step failed.
LogicalResult walkPimCoreBlock(Block& block, const StaticValueKnowledge& knowledge,
    llvm::function_ref<LogicalResult(Operation&, const StaticValueKnowledge&)> callback) {
  bool hasFailure = false;
  for (Operation& op : block) {
    // NOTE(review): the isa<...> skip list was lost in extraction; scf.yield
    // (the block terminator) is the minimal reconstruction — confirm.
    if (isa<scf::YieldOp>(op) || isCoreStaticAddressOp(&op))
      continue;
    if (auto forOp = dyn_cast<scf::ForOp>(op)) {
      Block& loopBody = forOp.getRegion().front();
      auto lowerBound = resolveIndexValue(forOp.getLowerBound(), knowledge);
      auto upperBound = resolveIndexValue(forOp.getUpperBound(), knowledge);
      auto step = resolveIndexValue(forOp.getStep(), knowledge);
      if (failed(lowerBound) || failed(upperBound) || failed(step) || *step <= 0) {
        forOp.emitOpError("requires statically evaluable scf.for bounds for PIM codegen");
        hasFailure = true;
        continue;
      }
      SmallVector<Value> iterValues(forOp.getInitArgs().begin(), forOp.getInitArgs().end());
      for (int64_t inductionValue = *lowerBound; inductionValue < *upperBound;
           inductionValue += *step) {
        StaticValueKnowledge loopKnowledge = knowledge;
        loopKnowledge.indexValues[forOp.getInductionVar()] = inductionValue;
        for (auto [iterArg, iterValue] :
             llvm::zip_equal(forOp.getRegionIterArgs(), iterValues))
          loopKnowledge.aliases[iterArg] = iterValue;
        if (failed(walkPimCoreBlock(loopBody, loopKnowledge, callback)))
          hasFailure = true;
        // Propagate yielded values into the next iteration's iter-arg aliases.
        auto yieldOp = cast<scf::YieldOp>(loopBody.getTerminator());
        for (auto [index, yieldedValue] : llvm::enumerate(yieldOp.getOperands()))
          iterValues[index] = resolveLoopCarriedAlias(yieldedValue, loopKnowledge);
      }
      continue;
    }
    if (failed(callback(op, knowledge)))
      hasFailure = true;
  }
  return success(!hasFailure);
}

} // namespace onnx_mlir