fix much stuff
This commit is contained in:
+1
-7
@@ -32,13 +32,7 @@ function(raptor_write_external_cmake_shim shim_dir external_source_dir descripti
|
||||
\"\${CMAKE_CURRENT_LIST_DIR}/${relative_external_source_dir}\"
|
||||
REALPATH
|
||||
)
|
||||
add_subdirectory(
|
||||
\"\${raptor_external_source_dir}\"
|
||||
\"\${CMAKE_CURRENT_BINARY_DIR}/raptor-external\"
|
||||
)
|
||||
if (DEFINED PIM_ENABLED)
|
||||
set(PIM_ENABLED \"\${PIM_ENABLED}\" PARENT_SCOPE)
|
||||
endif ()
|
||||
include(\"\${raptor_external_source_dir}/CMakeLists.txt\")
|
||||
"
|
||||
)
|
||||
|
||||
|
||||
@@ -21,13 +21,15 @@ namespace {
|
||||
|
||||
template <typename MVMOpTy, typename VMMOpTy, typename ParentOpTy>
|
||||
bool hasMvmVmmWeightUse(ParentOpTy parentOp, unsigned weightIndex) {
|
||||
mlir::Value weightArg = parentOp.getWeightArgument(weightIndex);
|
||||
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
||||
if (!weightArg)
|
||||
return false;
|
||||
bool found = false;
|
||||
parentOp.walk([&](mlir::Operation* op) {
|
||||
if (auto mvmOp = mlir::dyn_cast<MVMOpTy>(op))
|
||||
found |= mvmOp.getWeight() == weightArg;
|
||||
found |= mvmOp.getWeight() == *weightArg;
|
||||
else if (auto vmmOp = mlir::dyn_cast<VMMOpTy>(op))
|
||||
found |= vmmOp.getWeight() == weightArg;
|
||||
found |= vmmOp.getWeight() == *weightArg;
|
||||
});
|
||||
return found;
|
||||
}
|
||||
@@ -38,7 +40,8 @@ void walkMvmVmmWeightUses(ParentOpTy parentOp, llvm::function_ref<void(mlir::OpO
|
||||
llvm::SmallSet<unsigned, 8> visited;
|
||||
auto walkWeight = [&](mlir::Value weight) {
|
||||
for (unsigned weightIndex = 0; weightIndex < weights.size(); ++weightIndex) {
|
||||
if (parentOp.getWeightArgument(weightIndex) != weight)
|
||||
auto weightArg = parentOp.getWeightArgument(weightIndex);
|
||||
if (!weightArg || *weightArg != weight)
|
||||
continue;
|
||||
if (visited.insert(weightIndex).second)
|
||||
callback(parentOp->getOpOperand(weightIndex));
|
||||
|
||||
@@ -422,12 +422,17 @@ LogicalResult GemvToSpatialCompute::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
|
||||
SmallVector<Value> vmmOutputs;
|
||||
vmmOutputs.reserve(aHSlices[coreId].size());
|
||||
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size()))
|
||||
for (auto aHSliceId : llvm::seq<size_t>(0, aHSlices[coreId].size())) {
|
||||
auto weightArg = computeOp.getWeightArgument(aHSliceId);
|
||||
auto inputArg = computeOp.getInputArgument(aHSliceId);
|
||||
if (!weightArg || !inputArg)
|
||||
return failure();
|
||||
vmmOutputs.push_back(spatial::SpatVMMOp::create(rewriter,
|
||||
gemmLoc,
|
||||
currOutHSliceType,
|
||||
computeOp.getWeightArgument(aHSliceId),
|
||||
computeOp.getInputArgument(aHSliceId)));
|
||||
*weightArg,
|
||||
*inputArg));
|
||||
}
|
||||
if (vmmOutputs.empty()) {
|
||||
gemmOp.emitOpError("requires at least one non-empty slice when lowering tiled Gemm to Spatial VMMs");
|
||||
return failure();
|
||||
@@ -561,29 +566,31 @@ LogicalResult GemmToSpatialComputeBatch::matchAndRewrite(ONNXGemmOp gemmOp,
|
||||
rewriter.createBlock(&batchOp.getBody(), batchOp.getBody().end(), TypeRange(blockArgTypes), blockArgLocs);
|
||||
rewriter.setInsertionPointToEnd(body);
|
||||
|
||||
Value lane = batchOp.getLaneArgument();
|
||||
Value weight = batchOp.getWeightArgument(0);
|
||||
Value packedInput = batchOp.getInputArgument(0);
|
||||
Value packedOutput = batchOp.getOutputArgument(0);
|
||||
auto lane = batchOp.getLaneArgument();
|
||||
auto weight = batchOp.getWeightArgument(0);
|
||||
auto packedInput = batchOp.getInputArgument(0);
|
||||
auto packedOutput = batchOp.getOutputArgument(0);
|
||||
if (!lane || !weight || !packedInput || !packedOutput)
|
||||
return failure();
|
||||
|
||||
SmallVector<OpFoldResult> inputOffsets {lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> inputOffsets {*lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> inputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(aType.getDimSize(1))};
|
||||
SmallVector<OpFoldResult> unitStrides {rewriter.getIndexAttr(1), rewriter.getIndexAttr(1)};
|
||||
Value row =
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, aRowType, packedInput, inputOffsets, inputSizes, unitStrides)
|
||||
tensor::ExtractSliceOp::create(rewriter, loc, aRowType, *packedInput, inputOffsets, inputSizes, unitStrides)
|
||||
.getResult();
|
||||
|
||||
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, weight, row).getResult();
|
||||
Value vmmResult = spatial::SpatVMMOp::create(rewriter, loc, outRowType, *weight, row).getResult();
|
||||
Value laneResult = vmmResult;
|
||||
if (sharedBias)
|
||||
laneResult = spatial::SpatVAddOp::create(rewriter, loc, outRowType, vmmResult, sharedBias).getResult();
|
||||
|
||||
auto inParallelOp = spatial::SpatInParallelOp::create(rewriter, loc);
|
||||
rewriter.setInsertionPointToStart(&inParallelOp.getRegion().front());
|
||||
SmallVector<OpFoldResult> outputOffsets {lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> outputOffsets {*lane, rewriter.getIndexAttr(0)};
|
||||
SmallVector<OpFoldResult> outputSizes {rewriter.getIndexAttr(1), rewriter.getIndexAttr(outType.getDimSize(1))};
|
||||
tensor::ParallelInsertSliceOp::create(
|
||||
rewriter, loc, laneResult, packedOutput, outputOffsets, outputSizes, unitStrides);
|
||||
rewriter, loc, laneResult, *packedOutput, outputOffsets, outputSizes, unitStrides);
|
||||
rewriter.setInsertionPointAfter(batchOp);
|
||||
|
||||
rewriter.replaceOp(gemmOp, batchOp.getResults());
|
||||
|
||||
@@ -27,13 +27,16 @@ static bool canPromoteInputBlockArgument(BlockArgument arg) {
|
||||
return !arg.use_empty() && llvm::all_of(arg.getUsers(), isWeightMaterializationHelperUser);
|
||||
}
|
||||
|
||||
static bool canPromoteInputBlockArgument(std::optional<BlockArgument> arg) {
|
||||
return arg && canPromoteInputBlockArgument(*arg);
|
||||
}
|
||||
|
||||
static bool isDirectConstantValue(Value value) {
|
||||
return isa_and_nonnull<arith::ConstantOp, ONNXConstantOp>(value.getDefiningOp());
|
||||
}
|
||||
|
||||
template <typename ComputeOpTy>
|
||||
static bool hasPromotableWeightLikeInputs(ComputeOpTy compute) {
|
||||
Block& block = compute.getBody().front();
|
||||
for (auto [inputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
if (!isWeightLikeComputeOperand(input))
|
||||
continue;
|
||||
@@ -104,20 +107,30 @@ struct PromoteWeightLikeComputeInputsPattern : OpRewritePattern<spatial::SpatCom
|
||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights()))
|
||||
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex));
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
|
||||
auto oldWeightArg = compute.getWeightArgument(weightIndex);
|
||||
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
|
||||
if (!oldWeightArg || !newWeightArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute weight block argument during rewrite");
|
||||
mapper.map(*oldWeightArg, *newWeightArg);
|
||||
}
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
BlockArgument oldArg = compute.getInputArgument(oldInputIdx);
|
||||
auto oldArg = compute.getInputArgument(oldInputIdx);
|
||||
if (!oldArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute input block argument during rewrite");
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
mapper.map(oldArg, newCompute.getInputArgument(newInputIdx++));
|
||||
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
|
||||
if (!newInputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing rewritten compute input block argument");
|
||||
mapper.map(*oldArg, *newInputArg);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted weight-like operand");
|
||||
mapper.map(oldArg, *clonedValue);
|
||||
mapper.map(*oldArg, *clonedValue);
|
||||
}
|
||||
|
||||
for (Operation& op : oldBlock.without_terminator())
|
||||
@@ -184,12 +197,15 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
rewriter.getI32IntegerAttr(static_cast<int32_t>(compute.getLaneCount())),
|
||||
newWeights,
|
||||
newInputs);
|
||||
auto laneArg = compute.getLaneArgument();
|
||||
if (!laneArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch lane block argument");
|
||||
SmallVector<Type> newBlockArgTypes;
|
||||
SmallVector<Location> newBlockArgLocs;
|
||||
newBlockArgTypes.reserve(1 + newWeights.size() + newInputTypes.size() + compute.getNumResults());
|
||||
newBlockArgLocs.reserve(1 + newWeights.size() + newInputLocs.size() + compute.getNumResults());
|
||||
newBlockArgTypes.push_back(compute.getLaneArgument().getType());
|
||||
newBlockArgLocs.push_back(compute.getLaneArgument().getLoc());
|
||||
newBlockArgTypes.push_back(laneArg->getType());
|
||||
newBlockArgLocs.push_back(laneArg->getLoc());
|
||||
for (Value weight : newWeights) {
|
||||
newBlockArgTypes.push_back(weight.getType());
|
||||
newBlockArgLocs.push_back(weight.getLoc());
|
||||
@@ -197,8 +213,11 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
llvm::append_range(newBlockArgTypes, newInputTypes);
|
||||
llvm::append_range(newBlockArgLocs, newInputLocs);
|
||||
for (auto [resultIndex, resultType] : llvm::enumerate(compute.getResultTypes())) {
|
||||
auto outputArg = compute.getOutputArgument(resultIndex);
|
||||
if (!outputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument");
|
||||
newBlockArgTypes.push_back(resultType);
|
||||
newBlockArgLocs.push_back(compute.getOutputArgument(resultIndex).getLoc());
|
||||
newBlockArgLocs.push_back(outputArg->getLoc());
|
||||
}
|
||||
|
||||
auto* newBlock = rewriter.createBlock(
|
||||
@@ -211,25 +230,41 @@ struct PromoteWeightLikeComputeBatchInputsPattern : OpRewritePattern<spatial::Sp
|
||||
bodyRewriter.setInsertionPointToStart(newBlock);
|
||||
|
||||
IRMapping mapper;
|
||||
mapper.map(compute.getLaneArgument(), newCompute.getLaneArgument());
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights()))
|
||||
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex));
|
||||
auto newLaneArg = newCompute.getLaneArgument();
|
||||
if (!newLaneArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch lane block argument");
|
||||
mapper.map(*laneArg, *newLaneArg);
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(compute.getWeights())) {
|
||||
auto oldWeightArg = compute.getWeightArgument(weightIndex);
|
||||
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
|
||||
if (!oldWeightArg || !newWeightArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch weight block argument during rewrite");
|
||||
mapper.map(*oldWeightArg, *newWeightArg);
|
||||
}
|
||||
size_t newInputIdx = 0;
|
||||
for (auto [oldInputIdx, input] : llvm::enumerate(compute.getInputs())) {
|
||||
BlockArgument oldArg = compute.getInputArgument(oldInputIdx);
|
||||
auto oldArg = compute.getInputArgument(oldInputIdx);
|
||||
if (!oldArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch input block argument during rewrite");
|
||||
if (!promoteInput[oldInputIdx]) {
|
||||
mapper.map(oldArg, newCompute.getInputArgument(newInputIdx++));
|
||||
auto newInputArg = newCompute.getInputArgument(newInputIdx++);
|
||||
if (!newInputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing rewritten compute_batch input block argument");
|
||||
mapper.map(*oldArg, *newInputArg);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto clonedValue = materializeWeightLikeValueInBlock(input, bodyRewriter, mapper);
|
||||
if (failed(clonedValue))
|
||||
return rewriter.notifyMatchFailure(compute, "failed to materialize promoted batch weight-like operand");
|
||||
mapper.map(oldArg, *clonedValue);
|
||||
mapper.map(*oldArg, *clonedValue);
|
||||
}
|
||||
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults())) {
|
||||
auto outputArg = compute.getOutputArgument(resultIndex);
|
||||
if (!outputArg)
|
||||
return rewriter.notifyMatchFailure(compute, "missing compute_batch output block argument during rewrite");
|
||||
mapper.map(*outputArg, newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
|
||||
}
|
||||
for (auto resultIndex : llvm::seq<size_t>(0, compute.getNumResults()))
|
||||
mapper.map(compute.getOutputArgument(resultIndex),
|
||||
newBlock->getArgument(1 + newWeights.size() + newInputs.size() + resultIndex));
|
||||
|
||||
for (Operation& op : oldBlock)
|
||||
rewriter.clone(op, mapper);
|
||||
|
||||
@@ -183,17 +183,15 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
||||
{static_cast<int>(batchWeights.size()), static_cast<int>(batchInputs.size())});
|
||||
coreBatchOp->setAttr(onnx_mlir::kCoreIdsAttrName, rewriter.getDenseI32ArrayAttr(coreIds));
|
||||
|
||||
SmallVector<Value> hostOutputTensors;
|
||||
SmallVector<unsigned> returnOperandIndices;
|
||||
if (computeBatchOp.getNumResults() != 0) {
|
||||
hostOutputTensors.resize(computeBatchOp.getNumResults());
|
||||
returnOperandIndices.resize(computeBatchOp.getNumResults());
|
||||
for (auto [resultIndex, result] : llvm::enumerate(computeBatchOp.getResults())) {
|
||||
FailureOr<unsigned> returnOperandIndex = getDirectReturnOperandIndex(cast<OpResult>(result));
|
||||
if (failed(returnOperandIndex))
|
||||
return computeBatchOp.emitOpError(
|
||||
"resultful compute_batch lowering currently requires each result to be used directly by func.return");
|
||||
|
||||
hostOutputTensors[resultIndex] = outputTensors[*returnOperandIndex](rewriter, loc);
|
||||
result.replaceAllUsesWith(hostOutputTensors[resultIndex]);
|
||||
returnOperandIndices[resultIndex] = *returnOperandIndex;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -209,11 +207,20 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
||||
|
||||
IRMapping mapper;
|
||||
rewriter.setInsertionPointToStart(newBlock);
|
||||
mapper.map(computeBatchOp.getLaneArgument(), coreBatchOp.getLaneArgument());
|
||||
for (unsigned weightIndex = 0; weightIndex < computeBatchOp.getWeights().size(); ++weightIndex)
|
||||
mapper.map(computeBatchOp.getWeightArgument(weightIndex), coreBatchOp.getWeightArgument(weightIndex));
|
||||
auto oldLaneArg = computeBatchOp.getLaneArgument();
|
||||
if (!oldLaneArg)
|
||||
return computeBatchOp.emitOpError("expected compute_batch lane block argument before lowering");
|
||||
mapper.map(*oldLaneArg, coreBatchOp.getLaneArgument());
|
||||
for (unsigned weightIndex = 0; weightIndex < computeBatchOp.getWeights().size(); ++weightIndex) {
|
||||
auto oldWeightArg = computeBatchOp.getWeightArgument(weightIndex);
|
||||
if (!oldWeightArg)
|
||||
return computeBatchOp.emitOpError("expected compute_batch weight block arguments before lowering");
|
||||
mapper.map(*oldWeightArg, coreBatchOp.getWeightArgument(weightIndex));
|
||||
}
|
||||
for (unsigned inputIndex = 0; inputIndex < computeBatchOp.getInputs().size(); ++inputIndex) {
|
||||
BlockArgument oldArg = computeBatchOp.getInputArgument(inputIndex);
|
||||
auto oldArg = computeBatchOp.getInputArgument(inputIndex);
|
||||
if (!oldArg)
|
||||
return computeBatchOp.emitOpError("expected compute_batch input block arguments before lowering");
|
||||
BlockArgument newArg = coreBatchOp.getInputArgument(inputIndex);
|
||||
auto newArgType = cast<ShapedType>(newArg.getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, loc, newArgType);
|
||||
@@ -226,7 +233,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
||||
rewriter.getI32IntegerAttr(0),
|
||||
getTensorSizeInBytesAttr(rewriter, newArg))
|
||||
.getOutput();
|
||||
mapper.map(oldArg, copied);
|
||||
mapper.map(*oldArg, copied);
|
||||
}
|
||||
|
||||
auto materializeCapturedTensor = [&](Value capturedTensor) -> Value {
|
||||
@@ -248,13 +255,25 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
||||
return copied;
|
||||
};
|
||||
|
||||
SmallVector<Value> hostOutputTensors(returnOperandIndices.size());
|
||||
auto getOrCreateHostOutputTensor = [&](unsigned resultIndex, Location resultLoc) -> Value {
|
||||
Value& hostOutputTensor = hostOutputTensors[resultIndex];
|
||||
if (hostOutputTensor)
|
||||
return hostOutputTensor;
|
||||
|
||||
hostOutputTensor = outputTensors[returnOperandIndices[resultIndex]](rewriter, resultLoc);
|
||||
return hostOutputTensor;
|
||||
};
|
||||
|
||||
rewriter.setInsertionPointToEnd(newBlock);
|
||||
for (Operation& op : oldBlock) {
|
||||
if (isa<spatial::SpatYieldOp>(op))
|
||||
continue;
|
||||
|
||||
if (auto parallelOp = dyn_cast<spatial::SpatInParallelOp>(op)) {
|
||||
unsigned firstOutputArg = computeBatchOp.getOutputArgument(0).getArgNumber();
|
||||
auto firstOutputArg = computeBatchOp.getOutputArgument(0);
|
||||
if (!firstOutputArg)
|
||||
return computeBatchOp.emitOpError("expected compute_batch output block arguments before lowering");
|
||||
for (Operation& nestedOp : parallelOp.getRegion().front()) {
|
||||
auto insertSlice = dyn_cast<tensor::ParallelInsertSliceOp>(&nestedOp);
|
||||
if (!insertSlice)
|
||||
@@ -264,12 +283,12 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeBatchOp(spatial::SpatCompute
|
||||
if (!outputArg || outputArg.getOwner() != &oldBlock)
|
||||
return insertSlice.emitOpError("expected compute_batch output block argument destination");
|
||||
|
||||
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg;
|
||||
if (resultIndex >= hostOutputTensors.size())
|
||||
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
|
||||
if (resultIndex >= returnOperandIndices.size())
|
||||
return insertSlice.emitOpError("result index out of range while lowering host batch output");
|
||||
|
||||
Value mappedSource = mapper.lookup(insertSlice.getSource());
|
||||
auto hostTarget = hostOutputTensors[resultIndex];
|
||||
Value hostTarget = getOrCreateHostOutputTensor(resultIndex, insertSlice.getLoc());
|
||||
auto hostTargetType = cast<ShapedType>(hostTarget.getType());
|
||||
Value hostTargetOffset = createHostTargetOffset(rewriter, insertSlice, hostTargetType, mapper);
|
||||
Value zeroOffset = arith::ConstantIndexOp::create(rewriter, insertSlice.getLoc(), 0).getResult();
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#include <cassert>
|
||||
|
||||
#include "src/Accelerators/PIM/Conversion/SpatialToPim/ComputeLikeRegionUtils.hpp"
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialOps.hpp"
|
||||
|
||||
@@ -29,9 +31,17 @@ void replaceAndEraseDirectComputeLikeInput(PatternRewriter& rewriter,
|
||||
unsigned inputIndex,
|
||||
Value replacement) {
|
||||
Block& body = owner->getRegion(0).front();
|
||||
BlockArgument bodyArgument = isa<spatial::SpatCompute>(owner)
|
||||
? cast<spatial::SpatCompute>(owner).getInputArgument(inputIndex)
|
||||
: cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
|
||||
BlockArgument bodyArgument;
|
||||
if (auto compute = dyn_cast<spatial::SpatCompute>(owner)) {
|
||||
auto computeArg = compute.getInputArgument(inputIndex);
|
||||
assert(computeArg && "expected compute input block argument");
|
||||
bodyArgument = *computeArg;
|
||||
}
|
||||
else {
|
||||
auto batchArg = cast<spatial::SpatComputeBatch>(owner).getInputArgument(inputIndex);
|
||||
assert(batchArg && "expected compute_batch input block argument");
|
||||
bodyArgument = *batchArg;
|
||||
}
|
||||
unsigned bodyArgIndex = bodyArgument.getArgNumber();
|
||||
|
||||
rewriter.startOpModification(owner);
|
||||
|
||||
@@ -131,8 +131,12 @@ static bool inlineInputlessHelperComputeForWeightLikeUsers(spatial::SpatCompute
|
||||
|
||||
rewriter.setInsertionPoint(computeOp);
|
||||
IRMapping mapping;
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(computeOp.getWeights()))
|
||||
mapping.map(computeOp.getWeightArgument(weightIndex), weight);
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(computeOp.getWeights())) {
|
||||
auto weightArg = computeOp.getWeightArgument(weightIndex);
|
||||
if (!weightArg)
|
||||
return false;
|
||||
mapping.map(*weightArg, weight);
|
||||
}
|
||||
for (Operation& op : block.without_terminator()) {
|
||||
cloneMappedHelperOperands(&op, mapping, rewriter, constantFolder);
|
||||
Operation* clonedOp = rewriter.clone(op, mapping);
|
||||
@@ -164,31 +168,33 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
|
||||
auto yieldOp = cast<spatial::SpatYieldOp>(block.getTerminator());
|
||||
|
||||
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
|
||||
BlockArgument blockArg = computeOp.getInputArgument(inputIndex);
|
||||
auto blockArg = computeOp.getInputArgument(inputIndex);
|
||||
if (!blockArg)
|
||||
return computeOp.emitOpError("expected compute input block arguments during lowering");
|
||||
auto receiveOp = dyn_cast_or_null<spatial::SpatChannelReceiveOp>(input.getDefiningOp());
|
||||
if (receiveOp && !blockArg.use_empty()) {
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
||||
auto outputType = cast<ShapedType>(blockArg.getType());
|
||||
if (receiveOp && !blockArg->use_empty()) {
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg));
|
||||
auto outputType = cast<ShapedType>(blockArg->getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveOp.getLoc(), outputType);
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, blockArg);
|
||||
auto sizeAttr = getTensorSizeInBytesAttr(rewriter, *blockArg);
|
||||
Value received =
|
||||
PimReceiveOp::create(
|
||||
rewriter, receiveOp.getLoc(), outputBuffer.getType(), outputBuffer, sizeAttr, receiveOp.getSourceCoreId())
|
||||
.getOutput();
|
||||
blockArg.replaceAllUsesWith(received);
|
||||
blockArg->replaceAllUsesWith(received);
|
||||
markOpToRemove(receiveOp);
|
||||
continue;
|
||||
}
|
||||
|
||||
auto receiveTensorOp = dyn_cast_or_null<spatial::SpatChannelReceiveTensorOp>(input.getDefiningOp());
|
||||
if (receiveTensorOp && !blockArg.use_empty()) {
|
||||
if (receiveTensorOp && !blockArg->use_empty()) {
|
||||
FailureOr<SmallVector<int32_t>> sourceCoreIds = getConstantI32Values(receiveTensorOp.getSourceCoreIds());
|
||||
if (failed(sourceCoreIds))
|
||||
return receiveTensorOp.emitOpError("expected constant sourceCoreIds");
|
||||
for (int32_t& sourceCoreId : *sourceCoreIds)
|
||||
sourceCoreId = translateSpatialCoreIdToPimCoreId(sourceCoreId);
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(blockArg));
|
||||
auto outputType = cast<ShapedType>(blockArg.getType());
|
||||
rewriter.setInsertionPoint(getEarliestUserWithinBlock(*blockArg));
|
||||
auto outputType = cast<ShapedType>(blockArg->getType());
|
||||
auto outputBuffer = createEmptyTensorFromShaped(rewriter, receiveTensorOp.getLoc(), outputType);
|
||||
Value received = PimReceiveTensorOp::create(rewriter,
|
||||
receiveTensorOp.getLoc(),
|
||||
@@ -196,7 +202,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
|
||||
outputBuffer,
|
||||
rewriter.getDenseI32ArrayAttr(*sourceCoreIds))
|
||||
.getOutput();
|
||||
blockArg.replaceAllUsesWith(received);
|
||||
blockArg->replaceAllUsesWith(received);
|
||||
markOpToRemove(receiveTensorOp);
|
||||
}
|
||||
}
|
||||
@@ -238,12 +244,14 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
|
||||
rewriter.setInsertionPointToStart(&block);
|
||||
auto& coreOpBlocks = coreOp.getBody().getBlocks();
|
||||
for (auto [inputIndex, input] : llvm::enumerate(computeOp.getInputs())) {
|
||||
BlockArgument blockArg = computeOp.getInputArgument(inputIndex);
|
||||
if (blockArg.use_empty())
|
||||
auto blockArg = computeOp.getInputArgument(inputIndex);
|
||||
if (!blockArg)
|
||||
return computeOp.emitOpError("expected compute input block arguments during input materialization");
|
||||
if (blockArg->use_empty())
|
||||
continue;
|
||||
|
||||
if (auto constantOp = input.getDefiningOp<arith::ConstantOp>()) {
|
||||
blockArg.replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, constantFolder));
|
||||
blockArg->replaceAllUsesWith(getOrCreateHostConstantLike(constantOp, constantFolder));
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -261,7 +269,7 @@ LogicalResult raptor::SpatialToPimPass::lowerComputeOp(spatial::SpatCompute comp
|
||||
input,
|
||||
getTensorSizeInBytesAttr(rewriter, input))
|
||||
.getOutput();
|
||||
blockArg.replaceAllUsesWith(copied);
|
||||
blockArg->replaceAllUsesWith(copied);
|
||||
}
|
||||
if (!computeOp.getInputs().empty())
|
||||
block.eraseArguments(computeOp.getWeights().size(), computeOp.getInputs().size());
|
||||
|
||||
@@ -77,8 +77,10 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgValue = spatCompute.getInputArgument(*inputIndex);
|
||||
if (!BBArgValue)
|
||||
return failure();
|
||||
|
||||
if (BBArgValue.use_empty())
|
||||
if (BBArgValue->use_empty())
|
||||
continue;
|
||||
|
||||
rewriter.setInsertionPoint(&spatCompute.getBody().front().front());
|
||||
@@ -95,8 +97,10 @@ struct MoveExtractSliceIntoCompute final : OpRewritePattern<mlir::tensor::Extrac
|
||||
if (!inputIndex)
|
||||
return failure();
|
||||
auto BBArgValue = spatComputeBatch.getInputArgument(*inputIndex);
|
||||
if (!BBArgValue)
|
||||
return failure();
|
||||
|
||||
if (BBArgValue.use_empty())
|
||||
if (BBArgValue->use_empty())
|
||||
continue;
|
||||
|
||||
rewriter.setInsertionPoint(&spatComputeBatch.getBody().front().front());
|
||||
|
||||
@@ -43,8 +43,14 @@ def SpatCompute : SpatOp<"compute",
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::BlockArgument getWeightArgument(unsigned idx);
|
||||
::mlir::BlockArgument getInputArgument(unsigned idx);
|
||||
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
||||
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
|
||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
|
||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||
::mlir::FailureOr<std::tuple<::mlir::OpResult, SpatCompute>>
|
||||
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
@@ -70,10 +76,16 @@ def SpatComputeBatch : SpatOp<"compute_batch",
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
::mlir::BlockArgument getLaneArgument();
|
||||
::mlir::BlockArgument getWeightArgument(unsigned idx);
|
||||
::mlir::BlockArgument getInputArgument(unsigned idx);
|
||||
::mlir::BlockArgument getOutputArgument(unsigned idx);
|
||||
std::optional<::mlir::BlockArgument> getLaneArgument();
|
||||
std::optional<::mlir::BlockArgument> getWeightArgument(unsigned idx);
|
||||
std::optional<::mlir::BlockArgument> getInputArgument(unsigned idx);
|
||||
std::optional<::mlir::BlockArgument> getOutputArgument(unsigned idx);
|
||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||
insertWeight(unsigned idx, ::mlir::Value weight, ::mlir::Location loc);
|
||||
std::optional<std::tuple<::mlir::Value, ::mlir::BlockArgument>>
|
||||
insertInput(unsigned idx, ::mlir::Value input, ::mlir::Location loc);
|
||||
::mlir::FailureOr<std::tuple<::mlir::OpResult, ::mlir::BlockArgument, SpatComputeBatch>>
|
||||
insertOutput(::mlir::RewriterBase &rewriter, unsigned idx, ::mlir::Type type, ::mlir::Location loc);
|
||||
}];
|
||||
|
||||
let hasVerifier = 1;
|
||||
|
||||
@@ -6,11 +6,81 @@ using namespace mlir;
|
||||
|
||||
namespace onnx_mlir {
|
||||
namespace spatial {
|
||||
namespace {
|
||||
|
||||
BlockArgument SpatCompute::getWeightArgument(unsigned idx) { return getBody().front().getArgument(idx); }
|
||||
std::optional<BlockArgument> getBatchBodyArgument(Region& body, unsigned argIdx) {
|
||||
if (body.empty())
|
||||
return std::nullopt;
|
||||
|
||||
BlockArgument SpatCompute::getInputArgument(unsigned idx) {
|
||||
return getBody().front().getArgument(getWeights().size() + idx);
|
||||
Block& block = body.front();
|
||||
if (argIdx >= block.getNumArguments())
|
||||
return std::nullopt;
|
||||
return block.getArgument(argIdx);
|
||||
}
|
||||
|
||||
std::optional<BlockArgument> insertBatchBodyArgument(Region& body, unsigned argIdx, Type type, Location loc) {
|
||||
if (body.empty())
|
||||
return std::nullopt;
|
||||
return body.insertArgument(argIdx, type, loc);
|
||||
}
|
||||
|
||||
void setComputeOperandSegmentSizes(Operation* op, int32_t weightCount, int32_t inputCount) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
||||
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||
return;
|
||||
}
|
||||
cast<SpatComputeBatch>(op).getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::optional<BlockArgument> SpatCompute::getWeightArgument(unsigned idx) { return getBatchBodyArgument(getBody(), idx); }
|
||||
|
||||
std::optional<BlockArgument> SpatCompute::getInputArgument(unsigned idx) {
|
||||
return getBatchBodyArgument(getBody(), getWeights().size() + idx);
|
||||
}
|
||||
|
||||
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||
unsigned weightCount = getWeights().size();
|
||||
unsigned inputCount = getInputs().size();
|
||||
getOperation()->insertOperands(idx, ValueRange {weight});
|
||||
setComputeOperandSegmentSizes(
|
||||
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
||||
auto blockArg = insertBatchBodyArgument(getBody(), idx, weight.getType(), loc);
|
||||
if (!blockArg)
|
||||
return std::nullopt;
|
||||
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
|
||||
}
|
||||
|
||||
std::optional<std::tuple<Value, BlockArgument>> SpatCompute::insertInput(unsigned idx, Value input, Location loc) {
|
||||
unsigned weightCount = getWeights().size();
|
||||
unsigned inputCount = getInputs().size();
|
||||
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
||||
setComputeOperandSegmentSizes(
|
||||
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
||||
auto blockArg = insertBatchBodyArgument(getBody(), weightCount + idx, input.getType(), loc);
|
||||
if (!blockArg)
|
||||
return std::nullopt;
|
||||
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||
}
|
||||
|
||||
FailureOr<std::tuple<OpResult, SpatCompute>>
|
||||
SpatCompute::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||
if (idx > getNumResults())
|
||||
return failure();
|
||||
|
||||
rewriter.setInsertionPoint(getOperation());
|
||||
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
|
||||
resultTypes.insert(resultTypes.begin() + idx, type);
|
||||
auto newCompute = SpatCompute::create(rewriter, getLoc(), TypeRange(resultTypes), getWeights(), getInputs());
|
||||
newCompute->setAttrs((*this)->getAttrs());
|
||||
setComputeOperandSegmentSizes(
|
||||
newCompute.getOperation(), static_cast<int32_t>(newCompute.getWeights().size()), static_cast<int32_t>(newCompute.getInputs().size()));
|
||||
rewriter.inlineRegionBefore(getBody(), newCompute.getBody(), newCompute.getBody().end());
|
||||
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
|
||||
getResult(oldResultIdx).replaceAllUsesWith(newCompute.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
||||
rewriter.eraseOp(getOperation());
|
||||
return std::make_tuple(cast<OpResult>(newCompute.getResult(idx)), newCompute);
|
||||
}
|
||||
|
||||
void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
@@ -18,42 +88,102 @@ void SpatCompute::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn s
|
||||
return;
|
||||
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
|
||||
if (auto weightArg = getWeightArgument(index))
|
||||
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
|
||||
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
|
||||
if (auto inputArg = getInputArgument(index))
|
||||
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
|
||||
}
|
||||
|
||||
BlockArgument SpatComputeBatch::getLaneArgument() { return getBody().front().getArgument(0); }
|
||||
std::optional<BlockArgument> SpatComputeBatch::getLaneArgument() { return getBatchBodyArgument(getBody(), 0); }
|
||||
|
||||
BlockArgument SpatComputeBatch::getWeightArgument(unsigned idx) { return getBody().front().getArgument(1 + idx); }
|
||||
|
||||
BlockArgument SpatComputeBatch::getInputArgument(unsigned idx) {
|
||||
return getBody().front().getArgument(1 + getWeights().size() + idx);
|
||||
std::optional<BlockArgument> SpatComputeBatch::getWeightArgument(unsigned idx) {
|
||||
return getBatchBodyArgument(getBody(), 1 + idx);
|
||||
}
|
||||
|
||||
BlockArgument SpatComputeBatch::getOutputArgument(unsigned idx) {
|
||||
return getBody().front().getArgument(1 + getWeights().size() + getInputs().size() + idx);
|
||||
std::optional<BlockArgument> SpatComputeBatch::getInputArgument(unsigned idx) {
|
||||
return getBatchBodyArgument(getBody(), 1 + getWeights().size() + idx);
|
||||
}
|
||||
|
||||
std::optional<BlockArgument> SpatComputeBatch::getOutputArgument(unsigned idx) {
|
||||
return getBatchBodyArgument(getBody(), 1 + getWeights().size() + getInputs().size() + idx);
|
||||
}
|
||||
|
||||
std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertWeight(unsigned idx, Value weight, Location loc) {
|
||||
unsigned weightCount = getWeights().size();
|
||||
unsigned inputCount = getInputs().size();
|
||||
getOperation()->insertOperands(idx, ValueRange {weight});
|
||||
setComputeOperandSegmentSizes(
|
||||
getOperation(), static_cast<int32_t>(weightCount + 1), static_cast<int32_t>(inputCount));
|
||||
auto blockArg = insertBatchBodyArgument(getBody(), 1 + idx, weight.getType(), loc);
|
||||
if (!blockArg)
|
||||
return std::nullopt;
|
||||
return std::make_tuple(getOperation()->getOperand(idx), *blockArg);
|
||||
}
|
||||
|
||||
std::optional<std::tuple<Value, BlockArgument>> SpatComputeBatch::insertInput(unsigned idx, Value input, Location loc) {
|
||||
unsigned weightCount = getWeights().size();
|
||||
unsigned inputCount = getInputs().size();
|
||||
getOperation()->insertOperands(weightCount + idx, ValueRange {input});
|
||||
setComputeOperandSegmentSizes(
|
||||
getOperation(), static_cast<int32_t>(weightCount), static_cast<int32_t>(inputCount + 1));
|
||||
auto blockArg = insertBatchBodyArgument(getBody(), 1 + weightCount + idx, input.getType(), loc);
|
||||
if (!blockArg)
|
||||
return std::nullopt;
|
||||
return std::make_tuple(getOperation()->getOperand(weightCount + idx), *blockArg);
|
||||
}
|
||||
|
||||
FailureOr<std::tuple<OpResult, BlockArgument, SpatComputeBatch>>
|
||||
SpatComputeBatch::insertOutput(RewriterBase& rewriter, unsigned idx, Type type, Location loc) {
|
||||
if (idx > getNumResults())
|
||||
return failure();
|
||||
|
||||
rewriter.setInsertionPoint(getOperation());
|
||||
SmallVector<Type> resultTypes(getResultTypes().begin(), getResultTypes().end());
|
||||
resultTypes.insert(resultTypes.begin() + idx, type);
|
||||
auto newBatch =
|
||||
SpatComputeBatch::create(rewriter, getLoc(), TypeRange(resultTypes), getLaneCountAttr(), getWeights(), getInputs());
|
||||
newBatch->setAttrs((*this)->getAttrs());
|
||||
setComputeOperandSegmentSizes(
|
||||
newBatch.getOperation(), static_cast<int32_t>(newBatch.getWeights().size()), static_cast<int32_t>(newBatch.getInputs().size()));
|
||||
rewriter.inlineRegionBefore(getBody(), newBatch.getBody(), newBatch.getBody().end());
|
||||
if (newBatch.getBody().empty()) {
|
||||
rewriter.eraseOp(newBatch);
|
||||
return failure();
|
||||
}
|
||||
auto blockArg = newBatch.getBody().front().insertArgument(
|
||||
1 + newBatch.getWeights().size() + newBatch.getInputs().size() + idx, type, loc);
|
||||
for (unsigned oldResultIdx = 0; oldResultIdx < getNumResults(); ++oldResultIdx)
|
||||
getResult(oldResultIdx).replaceAllUsesWith(newBatch.getResult(oldResultIdx < idx ? oldResultIdx : oldResultIdx + 1));
|
||||
rewriter.eraseOp(getOperation());
|
||||
return std::make_tuple(cast<OpResult>(newBatch.getResult(idx)), blockArg, newBatch);
|
||||
}
|
||||
|
||||
void SpatComputeBatch::getAsmBlockArgumentNames(Region& region, OpAsmSetValueNameFn setNameFn) {
|
||||
if (region.empty())
|
||||
return;
|
||||
|
||||
setNameFn(getLaneArgument(), "lane");
|
||||
if (auto laneArg = getLaneArgument())
|
||||
setNameFn(*laneArg, "lane");
|
||||
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
setNameFn(getWeightArgument(index), ("w" + std::to_string(index)).c_str());
|
||||
if (auto weightArg = getWeightArgument(index))
|
||||
setNameFn(*weightArg, ("w" + std::to_string(index)).c_str());
|
||||
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
setNameFn(getInputArgument(index), ("in" + std::to_string(index)).c_str());
|
||||
if (auto inputArg = getInputArgument(index))
|
||||
setNameFn(*inputArg, ("in" + std::to_string(index)).c_str());
|
||||
|
||||
for (unsigned index = 0; index < getNumResults(); ++index) {
|
||||
auto outputArg = getOutputArgument(index);
|
||||
if (!outputArg)
|
||||
continue;
|
||||
if (index == 0) {
|
||||
setNameFn(getOutputArgument(index), "out");
|
||||
setNameFn(*outputArg, "out");
|
||||
continue;
|
||||
}
|
||||
setNameFn(getOutputArgument(index), ("out" + std::to_string(index)).c_str());
|
||||
setNameFn(*outputArg, ("out" + std::to_string(index)).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,12 +5,15 @@
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/RegionKindInterface.h"
|
||||
#include "mlir/IR/Types.h"
|
||||
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
|
||||
|
||||
#include <map>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
/// Include the auto-generated header files containing the declarations
|
||||
#include "src/Accelerators/PIM/Dialect/Spatial/SpatialDialect.hpp.inc"
|
||||
|
||||
@@ -218,17 +218,26 @@ ParseResult SpatConcatOp::parse(OpAsmParser& parser, OperationState& result) {
|
||||
}
|
||||
|
||||
void SpatCompute::print(OpAsmPrinter& printer) {
|
||||
printer << " ";
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(getWeights().size());
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
weightArgs.push_back(getWeightArgument(index));
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index) {
|
||||
auto weightArg = getWeightArgument(index);
|
||||
if (!weightArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
weightArgs.push_back(*weightArg);
|
||||
}
|
||||
SmallVector<Value> inputArgs;
|
||||
inputArgs.reserve(getInputs().size());
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
inputArgs.push_back(getInputArgument(index));
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index) {
|
||||
auto inputArg = getInputArgument(index);
|
||||
if (!inputArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
inputArgs.push_back(*inputArg);
|
||||
}
|
||||
|
||||
printer << " ";
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||
|
||||
if (auto coreIdAttr = (*this)->getAttrOfType<IntegerAttr>(onnx_mlir::kCoreIdAttrName))
|
||||
@@ -309,29 +318,48 @@ ParseResult SpatCompute::parse(OpAsmParser& parser, OperationState& result) {
|
||||
}
|
||||
|
||||
void SpatComputeBatch::print(OpAsmPrinter& printer) {
|
||||
auto laneArg = getLaneArgument();
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(getWeights().size());
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index) {
|
||||
auto weightArg = getWeightArgument(index);
|
||||
if (!weightArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
weightArgs.push_back(*weightArg);
|
||||
}
|
||||
SmallVector<Value> inputArgs;
|
||||
inputArgs.reserve(getInputs().size());
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index) {
|
||||
auto inputArg = getInputArgument(index);
|
||||
if (!inputArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
inputArgs.push_back(*inputArg);
|
||||
}
|
||||
|
||||
SmallVector<BlockArgument> outputArgs;
|
||||
if (!laneArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
if (getNumResults() != 0) {
|
||||
outputArgs.reserve(getNumResults());
|
||||
for (unsigned index = 0; index < getNumResults(); ++index) {
|
||||
auto outputArg = getOutputArgument(index);
|
||||
if (!outputArg)
|
||||
return printer.printGenericOp(getOperation(), /*printOpName=*/false);
|
||||
outputArgs.push_back(*outputArg);
|
||||
}
|
||||
}
|
||||
|
||||
printer << " ";
|
||||
printer.printOperand(getLaneArgument());
|
||||
printer.printOperand(*laneArg);
|
||||
printer << " = 0 to " << getLaneCount();
|
||||
|
||||
printer << " ";
|
||||
SmallVector<Value> weightArgs;
|
||||
weightArgs.reserve(getWeights().size());
|
||||
for (unsigned index = 0; index < getWeights().size(); ++index)
|
||||
weightArgs.push_back(getWeightArgument(index));
|
||||
printBoundValueList(printer, weightArgs, getWeights(), ListDelimiter::Square);
|
||||
printer << " ";
|
||||
SmallVector<Value> inputArgs;
|
||||
inputArgs.reserve(getInputs().size());
|
||||
for (unsigned index = 0; index < getInputs().size(); ++index)
|
||||
inputArgs.push_back(getInputArgument(index));
|
||||
printBoundValueList(printer, inputArgs, getInputs(), ListDelimiter::Paren);
|
||||
|
||||
if (getNumResults() != 0) {
|
||||
printer << " shared_outs";
|
||||
SmallVector<BlockArgument> outputArgs;
|
||||
outputArgs.reserve(getNumResults());
|
||||
for (unsigned index = 0; index < getNumResults(); ++index)
|
||||
outputArgs.push_back(getOutputArgument(index));
|
||||
printBlockArgumentList(printer, outputArgs);
|
||||
}
|
||||
|
||||
|
||||
@@ -107,8 +107,11 @@ static bool isBatchOutputArgument(SpatComputeBatch batchOp, Value value) {
|
||||
return false;
|
||||
|
||||
unsigned argNumber = blockArg.getArgNumber();
|
||||
unsigned firstOutputArg = batchOp.getOutputArgument(0).getArgNumber();
|
||||
return argNumber >= firstOutputArg && argNumber < firstOutputArg + batchOp.getNumResults();
|
||||
auto firstOutputArg = batchOp.getOutputArgument(0);
|
||||
if (!firstOutputArg)
|
||||
return false;
|
||||
unsigned firstOutputArgNumber = firstOutputArg->getArgNumber();
|
||||
return argNumber >= firstOutputArgNumber && argNumber < firstOutputArgNumber + batchOp.getNumResults();
|
||||
}
|
||||
|
||||
static bool isConstantIndexLike(Value value) {
|
||||
@@ -293,10 +296,12 @@ static LogicalResult verifyBatchBody(SpatComputeBatch batchOp, Block& block) {
|
||||
return batchOp.emitError("resultful compute_batch body must terminate with spat.in_parallel");
|
||||
}
|
||||
|
||||
BlockArgument laneArg = batchOp.getLaneArgument();
|
||||
auto laneArg = batchOp.getLaneArgument();
|
||||
if (!laneArg)
|
||||
return batchOp.emitError("compute_batch body must have a lane block argument");
|
||||
for (auto& bodyOp : block) {
|
||||
if (auto extractSlice = dyn_cast<tensor::ExtractSliceOp>(&bodyOp))
|
||||
if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, laneArg, "tensor.extract_slice")))
|
||||
if (failed(verifyStaticUnitStrideExtractSliceOp(extractSlice, *laneArg, "tensor.extract_slice")))
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
@@ -457,12 +462,16 @@ LogicalResult SpatCompute::verify() {
|
||||
if (block.getNumArguments() != expectedArgCount)
|
||||
return emitError("compute body must have weight and input block arguments");
|
||||
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
|
||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
auto blockArg = getWeightArgument(weightIndex);
|
||||
if (!blockArg || blockArg->getType() != weight.getType())
|
||||
return emitError("compute weight block argument types must match weight operand types exactly");
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs()))
|
||||
if (getInputArgument(inputIndex).getType() != input.getType())
|
||||
}
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||
auto blockArg = getInputArgument(inputIndex);
|
||||
if (!blockArg || blockArg->getType() != input.getType())
|
||||
return emitError("compute input block argument types must match input operand types exactly");
|
||||
}
|
||||
|
||||
if (block.mightHaveTerminator()) {
|
||||
auto yieldOp = dyn_cast_or_null<SpatYieldOp>(block.getTerminator());
|
||||
@@ -497,7 +506,7 @@ LogicalResult SpatCompute::verify() {
|
||||
}
|
||||
|
||||
for (unsigned inputIndex = 0; inputIndex < getInputs().size(); ++inputIndex)
|
||||
if (getInputArgument(inputIndex).use_empty())
|
||||
if (auto inputArg = getInputArgument(inputIndex); !inputArg || inputArg->use_empty())
|
||||
return emitError("ComputeOp block argument is not used");
|
||||
if (failed(verifyOnlyConstantExternalValues(this->getOperation(), getBody(), "spat.compute")))
|
||||
return failure();
|
||||
@@ -574,23 +583,28 @@ LogicalResult SpatComputeBatch::verify() {
|
||||
}
|
||||
|
||||
Block& block = getBody().front();
|
||||
if (block.getNumArguments() == 0)
|
||||
return emitError("compute_batch body must have exactly one lane block argument");
|
||||
unsigned expectedArgCount = 1 + getWeights().size() + getInputs().size() + getNumResults();
|
||||
if (block.getNumArguments() != expectedArgCount)
|
||||
return emitError("compute_batch body must have lane, weight, input, and output block arguments");
|
||||
if (!getLaneArgument().getType().isIndex())
|
||||
return emitError("compute_batch body block arguments must match lane, weight, input, and output operands/results");
|
||||
auto laneArg = getLaneArgument();
|
||||
if (!laneArg || !laneArg->getType().isIndex())
|
||||
return emitError("compute_batch first block argument must have index type");
|
||||
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights()))
|
||||
if (getWeightArgument(weightIndex).getType() != weight.getType())
|
||||
for (auto [weightIndex, weight] : llvm::enumerate(getWeights())) {
|
||||
auto blockArg = getWeightArgument(weightIndex);
|
||||
if (!blockArg || blockArg->getType() != weight.getType())
|
||||
return emitError("compute_batch weight block argument types must match weight operand types exactly");
|
||||
}
|
||||
for (auto [inputIndex, input] : llvm::enumerate(getInputs())) {
|
||||
BlockArgument blockArg = getInputArgument(inputIndex);
|
||||
if (blockArg.getType() != input.getType())
|
||||
auto blockArg = getInputArgument(inputIndex);
|
||||
if (!blockArg || blockArg->getType() != input.getType())
|
||||
return emitError("compute_batch input block argument types must match input operand types exactly");
|
||||
}
|
||||
for (auto [resultIndex, resultType] : llvm::enumerate(getResultTypes())) {
|
||||
BlockArgument blockArg = getOutputArgument(resultIndex);
|
||||
if (blockArg.getType() != resultType)
|
||||
auto blockArg = getOutputArgument(resultIndex);
|
||||
if (!blockArg || blockArg->getType() != resultType)
|
||||
return emitError("compute_batch output block argument types must match result types exactly");
|
||||
}
|
||||
|
||||
@@ -608,13 +622,15 @@ LogicalResult SpatInParallelOp::verify() {
|
||||
if (batchOp.getNumResults() == 0)
|
||||
return emitOpError("requires a resultful spat.compute_batch parent");
|
||||
|
||||
BlockArgument laneArg = batchOp.getLaneArgument();
|
||||
auto laneArg = batchOp.getLaneArgument();
|
||||
if (!laneArg)
|
||||
return emitOpError("expected compute_batch lane block argument");
|
||||
for (Operation& op : getRegion().front().getOperations()) {
|
||||
auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
|
||||
if (!insertSliceOp)
|
||||
return emitOpError("expected only tensor.parallel_insert_slice ops");
|
||||
|
||||
if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertSliceOp, laneArg, "tensor.parallel_insert_slice")))
|
||||
if (failed(verifyStaticUnitStrideParallelInsertSliceOp(insertSliceOp, *laneArg, "tensor.parallel_insert_slice")))
|
||||
return failure();
|
||||
|
||||
MutableOperandRange destinations = insertSliceOp.getUpdatedDestinations();
|
||||
|
||||
@@ -432,15 +432,6 @@ LogicalResult collectHostOutputs(MaterializerState& state) {
|
||||
return success();
|
||||
}
|
||||
|
||||
void setOperandSegmentSizes(Operation* op, int weightCount, int inputCount) {
|
||||
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
||||
compute.getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||
return;
|
||||
}
|
||||
auto batch = cast<SpatComputeBatch>(op);
|
||||
batch.getProperties().setOperandSegmentSizes({weightCount, inputCount});
|
||||
}
|
||||
|
||||
void createEmptyMaterializedOps(MaterializerState& state) {
|
||||
Location loc = state.func.getLoc();
|
||||
Block& funcBlock = state.func.getBody().front();
|
||||
@@ -529,19 +520,17 @@ BlockArgument appendWeight(MaterializerState& state, MaterializedClass& material
|
||||
materializedClass.weights.push_back(weight);
|
||||
|
||||
if (auto compute = dyn_cast<SpatCompute>(materializedClass.op)) {
|
||||
compute.getWeightsMutable().append(ValueRange(weight));
|
||||
setOperandSegmentSizes(materializedClass.op, materializedClass.weights.size(), materializedClass.inputs.size());
|
||||
BlockArgument arg = materializedClass.body->insertArgument(weightIndex, weight.getType(), weight.getLoc());
|
||||
materializedClass.weightArgs[weight] = arg;
|
||||
return arg;
|
||||
auto arg = compute.insertWeight(weightIndex, weight, weight.getLoc());
|
||||
assert(arg && "expected compute body while inserting a weight");
|
||||
materializedClass.weightArgs[weight] = std::get<1>(*arg);
|
||||
return std::get<1>(*arg);
|
||||
}
|
||||
|
||||
auto batch = cast<SpatComputeBatch>(materializedClass.op);
|
||||
batch.getWeightsMutable().append(ValueRange(weight));
|
||||
setOperandSegmentSizes(materializedClass.op, materializedClass.weights.size(), materializedClass.inputs.size());
|
||||
BlockArgument arg = materializedClass.body->insertArgument(1 + weightIndex, weight.getType(), weight.getLoc());
|
||||
materializedClass.weightArgs[weight] = arg;
|
||||
return arg;
|
||||
auto arg = batch.insertWeight(weightIndex, weight, weight.getLoc());
|
||||
assert(arg && "expected compute_batch body while inserting a weight argument");
|
||||
materializedClass.weightArgs[weight] = std::get<1>(*arg);
|
||||
return std::get<1>(*arg);
|
||||
}
|
||||
|
||||
BlockArgument appendInput(MaterializerState& state, MaterializedClass& materializedClass, Value input) {
|
||||
@@ -551,17 +540,16 @@ BlockArgument appendInput(MaterializerState& state, MaterializedClass& materiali
|
||||
|
||||
materializedClass.inputs.push_back(input);
|
||||
if (auto compute = dyn_cast<SpatCompute>(materializedClass.op)) {
|
||||
compute.getInputsMutable().append(ValueRange(input));
|
||||
BlockArgument arg = materializedClass.body->addArgument(input.getType(), input.getLoc());
|
||||
materializedClass.inputArgs[input] = arg;
|
||||
auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc());
|
||||
assert(arg && "expected compute body while inserting an input");
|
||||
materializedClass.inputArgs[input] = std::get<1>(*arg);
|
||||
return std::get<1>(*arg);
|
||||
}
|
||||
else {
|
||||
cast<SpatComputeBatch>(materializedClass.op).getInputsMutable().append(ValueRange(input));
|
||||
setOperandSegmentSizes(materializedClass.op, materializedClass.weights.size(), materializedClass.inputs.size());
|
||||
BlockArgument arg = materializedClass.body->insertArgument(
|
||||
materializedClass.body->getNumArguments() - 1, input.getType(), input.getLoc());
|
||||
materializedClass.inputArgs[input] = arg;
|
||||
return arg;
|
||||
if (auto compute = dyn_cast<SpatComputeBatch>(materializedClass.op)) {
|
||||
auto arg = compute.insertInput(materializedClass.inputs.size() - 1, input, input.getLoc());
|
||||
assert(arg && "expected compute_batch body while inserting an input argument");
|
||||
materializedClass.inputArgs[input] = std::get<1>(*arg);
|
||||
return std::get<1>(*arg);
|
||||
}
|
||||
llvm_unreachable("Cannot reach here");
|
||||
}
|
||||
@@ -608,6 +596,8 @@ Value createOriginalLaneValue(MaterializerState& state,
|
||||
return createIndexConstant(state, materializedClass.op, peers.front().laneStart);
|
||||
|
||||
auto batch = cast<SpatComputeBatch>(materializedClass.op);
|
||||
auto laneArg = batch.getLaneArgument();
|
||||
assert(laneArg && "expected materialized compute_batch lane argument");
|
||||
bool identity = true;
|
||||
for (auto [lane, peer] : llvm::enumerate(peers)) {
|
||||
if (peer.laneCount != 1 || peer.laneStart != lane) {
|
||||
@@ -616,7 +606,7 @@ Value createOriginalLaneValue(MaterializerState& state,
|
||||
}
|
||||
}
|
||||
if (identity)
|
||||
return batch.getLaneArgument();
|
||||
return *laneArg;
|
||||
|
||||
bool affineWithBase = true;
|
||||
int64_t base = static_cast<int64_t>(peers.front().laneStart);
|
||||
@@ -628,9 +618,9 @@ Value createOriginalLaneValue(MaterializerState& state,
|
||||
}
|
||||
if (affineWithBase) {
|
||||
if (base == 0)
|
||||
return batch.getLaneArgument();
|
||||
return *laneArg;
|
||||
Value baseValue = createIndexConstant(state, materializedClass.op, base);
|
||||
return arith::AddIOp::create(state.rewriter, loc, batch.getLaneArgument(), baseValue).getResult();
|
||||
return arith::AddIOp::create(state.rewriter, loc, *laneArg, baseValue).getResult();
|
||||
}
|
||||
|
||||
SmallVector<APInt, 8> laneValues;
|
||||
@@ -641,7 +631,7 @@ Value createOriginalLaneValue(MaterializerState& state,
|
||||
auto tableType = RankedTensorType::get({static_cast<int64_t>(peers.size())}, state.rewriter.getIndexType());
|
||||
auto tableAttr = DenseIntElementsAttr::get(tableType, laneValues);
|
||||
Value table = arith::ConstantOp::create(state.rewriter, loc, tableType, tableAttr).getResult();
|
||||
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {batch.getLaneArgument()}).getResult();
|
||||
return tensor::ExtractOp::create(state.rewriter, loc, table, ValueRange {*laneArg}).getResult();
|
||||
}
|
||||
|
||||
bool hasLiveExternalUse(Value value, const DenseSet<Operation*>& oldComputeOps) {
|
||||
@@ -838,7 +828,10 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val
|
||||
offsets.reserve(payloadType.getRank());
|
||||
sizes.reserve(payloadType.getRank());
|
||||
strides.reserve(payloadType.getRank());
|
||||
offsets.push_back(batch.getLaneArgument());
|
||||
auto laneArg = batch.getLaneArgument();
|
||||
if (!laneArg)
|
||||
return batch.emitOpError("expected compute_batch lane block argument while materializing batch output");
|
||||
offsets.push_back(*laneArg);
|
||||
sizes.push_back(state.rewriter.getIndexAttr(1));
|
||||
strides.push_back(state.rewriter.getIndexAttr(1));
|
||||
for (int64_t dim = 1; dim < payloadType.getRank(); ++dim) {
|
||||
@@ -847,8 +840,11 @@ setHostOutputValue(MaterializerState& state, MaterializedClass& sourceClass, Val
|
||||
strides.push_back(state.rewriter.getIndexAttr(1));
|
||||
}
|
||||
|
||||
tensor::ParallelInsertSliceOp::create(
|
||||
state.rewriter, payload.getLoc(), payload, batch.getOutputArgument(resultIndex), offsets, sizes, strides);
|
||||
auto outputArg = batch.getOutputArgument(resultIndex);
|
||||
if (!outputArg)
|
||||
return batch.emitOpError("expected compute_batch output block argument while materializing batch output");
|
||||
|
||||
tensor::ParallelInsertSliceOp::create(state.rewriter, payload.getLoc(), payload, *outputArg, offsets, sizes, strides);
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -1136,14 +1132,20 @@ void mapWeights(MaterializerState& state,
|
||||
IRMapping& mapper) {
|
||||
Operation* op = instance.op;
|
||||
if (auto compute = dyn_cast<SpatCompute>(op)) {
|
||||
for (auto [index, weight] : llvm::enumerate(compute.getWeights()))
|
||||
mapper.map(compute.getWeightArgument(index), appendWeight(state, targetClass, weight));
|
||||
for (auto [index, weight] : llvm::enumerate(compute.getWeights())) {
|
||||
auto weightArg = compute.getWeightArgument(index);
|
||||
assert(weightArg && "expected compute weight block argument");
|
||||
mapper.map(*weightArg, appendWeight(state, targetClass, weight));
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
auto batch = cast<SpatComputeBatch>(op);
|
||||
for (auto [index, weight] : llvm::enumerate(batch.getWeights()))
|
||||
mapper.map(batch.getWeightArgument(index), appendWeight(state, targetClass, weight));
|
||||
for (auto [index, weight] : llvm::enumerate(batch.getWeights())) {
|
||||
auto weightArg = batch.getWeightArgument(index);
|
||||
assert(weightArg && "expected compute_batch weight block argument");
|
||||
mapper.map(*weightArg, appendWeight(state, targetClass, weight));
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult mapInputs(MaterializerState& state,
|
||||
@@ -1156,7 +1158,10 @@ LogicalResult mapInputs(MaterializerState& state,
|
||||
FailureOr<Value> mapped = resolveInputValue(state, targetClass, input, instance);
|
||||
if (failed(mapped))
|
||||
return compute.emitOpError("failed to resolve materialized compute input");
|
||||
mapper.map(compute.getInputArgument(index), *mapped);
|
||||
auto inputArg = compute.getInputArgument(index);
|
||||
if (!inputArg)
|
||||
return compute.emitOpError("expected compute input block argument while materializing inputs");
|
||||
mapper.map(*inputArg, *mapped);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
@@ -1166,7 +1171,10 @@ LogicalResult mapInputs(MaterializerState& state,
|
||||
FailureOr<Value> mapped = resolveInputValue(state, targetClass, input, instance);
|
||||
if (failed(mapped))
|
||||
return batch.emitOpError("failed to resolve materialized compute_batch input");
|
||||
mapper.map(batch.getInputArgument(index), *mapped);
|
||||
auto inputArg = batch.getInputArgument(index);
|
||||
if (!inputArg)
|
||||
return batch.emitOpError("expected compute_batch input block argument while materializing inputs");
|
||||
mapper.map(*inputArg, *mapped);
|
||||
}
|
||||
return success();
|
||||
}
|
||||
@@ -1186,8 +1194,10 @@ SmallVector<Value, 4> collectMappedBatchOutputs(SpatComputeBatch batch, IRMappin
|
||||
if (!outputArg || outputArg.getOwner() != &batch.getBody().front())
|
||||
continue;
|
||||
|
||||
unsigned firstOutputArg = batch.getOutputArgument(0).getArgNumber();
|
||||
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg;
|
||||
auto firstOutputArg = batch.getOutputArgument(0);
|
||||
if (!firstOutputArg)
|
||||
return outputs;
|
||||
unsigned resultIndex = outputArg.getArgNumber() - firstOutputArg->getArgNumber();
|
||||
if (resultIndex >= outputs.size())
|
||||
continue;
|
||||
outputs[resultIndex] = mapper.lookupOrDefault(insert.getSource());
|
||||
@@ -1217,7 +1227,12 @@ cloneInstanceBody(MaterializerState& state, MaterializedClass& targetClass, Arra
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
mapper.map(batch.getLaneArgument(), createOriginalLaneValue(state, targetClass, peers, loc));
|
||||
auto laneArg = batch.getLaneArgument();
|
||||
if (!laneArg) {
|
||||
sourceOp->emitError("expected source compute_batch lane block argument");
|
||||
return failure();
|
||||
}
|
||||
mapper.map(*laneArg, createOriginalLaneValue(state, targetClass, peers, loc));
|
||||
}
|
||||
|
||||
mapWeights(state, targetClass, instance, mapper);
|
||||
|
||||
@@ -223,18 +223,32 @@ void mergeTriviallyConnectedComputes(func::FuncOp funcOp) {
|
||||
newBody->addArgument(input.getType(), loc);
|
||||
|
||||
IRMapping mapper;
|
||||
for (auto [weightIndex, _] : llvm::enumerate(compute.getWeights()))
|
||||
mapper.map(compute.getWeightArgument(weightIndex), newCompute.getWeightArgument(weightIndex));
|
||||
for (auto [inputIndex, _] : llvm::enumerate(compute.getInputs()))
|
||||
mapper.map(compute.getInputArgument(inputIndex), newCompute.getInputArgument(inputIndex));
|
||||
for (auto [oldIndex, weight] : llvm::enumerate(child.getWeights()))
|
||||
mapper.map(child.getWeightArgument(oldIndex), newCompute.getWeightArgument(childWeightToNewIndex[oldIndex]));
|
||||
for (auto [weightIndex, _] : llvm::enumerate(compute.getWeights())) {
|
||||
auto oldWeightArg = compute.getWeightArgument(weightIndex);
|
||||
auto newWeightArg = newCompute.getWeightArgument(weightIndex);
|
||||
assert(oldWeightArg && newWeightArg && "expected compute weight block arguments");
|
||||
mapper.map(*oldWeightArg, *newWeightArg);
|
||||
}
|
||||
for (auto [inputIndex, _] : llvm::enumerate(compute.getInputs())) {
|
||||
auto oldInputArg = compute.getInputArgument(inputIndex);
|
||||
auto newInputArg = newCompute.getInputArgument(inputIndex);
|
||||
assert(oldInputArg && newInputArg && "expected compute input block arguments");
|
||||
mapper.map(*oldInputArg, *newInputArg);
|
||||
}
|
||||
for (auto [oldIndex, weight] : llvm::enumerate(child.getWeights())) {
|
||||
auto oldWeightArg = child.getWeightArgument(oldIndex);
|
||||
auto newWeightArg = newCompute.getWeightArgument(childWeightToNewIndex[oldIndex]);
|
||||
assert(oldWeightArg && newWeightArg && "expected child compute weight block arguments");
|
||||
mapper.map(*oldWeightArg, *newWeightArg);
|
||||
}
|
||||
|
||||
rewriter.setInsertionPointToEnd(newBody);
|
||||
auto computeYield = cast<spatial::SpatYieldOp>(compute.getBody().front().getTerminator());
|
||||
for (Operation& op : compute.getBody().front().without_terminator())
|
||||
rewriter.clone(op, mapper);
|
||||
mapper.map(child.getInputArgument(childInputIndex), mapper.lookupOrDefault(computeYield.getOperand(usedResult)));
|
||||
auto childInputArg = child.getInputArgument(childInputIndex);
|
||||
assert(childInputArg && "expected child compute input block argument");
|
||||
mapper.map(*childInputArg, mapper.lookupOrDefault(computeYield.getOperand(usedResult)));
|
||||
|
||||
rewriter.setInsertionPointToEnd(newBody);
|
||||
for (auto& op : child.getBody().front())
|
||||
|
||||
@@ -30,7 +30,7 @@ python3 validation/operations/gen_tests.py
|
||||
|
||||
| Test | Directory | A (input) | W (weight) | Output | transB | alpha | beta | Bias | Notes |
|
||||
|---------------|-------------------------|-----------|------------|----------|--------|-------|------|-------|------------------------------|
|
||||
| Default | `gemm/` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Hand-crafted, square weights |
|
||||
| Simple | `gemm/simple` | [10,132] | [132,132] | [10,132] | no | 1 | 1 | no | Square weights |
|
||||
| Non-square | `gemm/non_square` | [4,128] | [128,64] | [4,64] | no | 1 | 1 | no | K != N |
|
||||
| With bias | `gemm/with_bias` | [4,128] | [128,128] | [4,128] | no | 1 | 1 | [128] | Bias vector |
|
||||
| transB | `gemm/transB` | [4,128] | [64,128] | [4,64] | yes | 1 | 1 | no | Transposed weight |
|
||||
|
||||
@@ -185,6 +185,18 @@ def conv_depthwise_grouped():
|
||||
# GEMM tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def gemm_simple():
|
||||
"""Simple GEMM with square weights: [10, 132] @ [132, 132]."""
|
||||
B, K, N = 10, 132, 132
|
||||
W = numpy_helper.from_array(np.random.default_rng(41).uniform(-1, 1, (K, N)).astype(np.float32), name="W")
|
||||
A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [B, K])
|
||||
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [B, N])
|
||||
node = helper.make_node("Gemm", ["A", "W"], ["Y"])
|
||||
graph = helper.make_graph([node], "gemm_simple", [A], [Y], initializer=[W])
|
||||
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
|
||||
save_model(model, "gemm/simple", "gemm_simple.onnx")
|
||||
|
||||
|
||||
def gemm_non_square():
|
||||
"""GEMM with non-square weight matrix: [B, K] @ [K, N], K != N."""
|
||||
B, K, N = 4, 128, 64
|
||||
@@ -823,6 +835,7 @@ def div_after_gemm():
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Generating GEMM tests:")
|
||||
gemm_simple()
|
||||
gemm_non_square()
|
||||
gemm_with_bias()
|
||||
gemm_transB()
|
||||
|
||||
Reference in New Issue
Block a user