better MaterializeMergeSchedule.cpp (something still broken downstream)
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-05-22 06:56:39 +02:00
parent 6aaf1c0870
commit 43ed3914b8
13 changed files with 1433 additions and 1620 deletions
+41 -7
View File
@@ -28,23 +28,47 @@ static SmallVector<int32_t> getLaneChunkCoreIds(ArrayRef<int32_t> coreIds, size_
return laneCoreIds;
}
static Value getOrCloneCapturedValue(OpBuilder& builder, Block& oldBlock, Value value, IRMapping& mapper) {
if (Value mapped = mapper.lookupOrNull(value))
return mapped;
if (auto blockArgument = dyn_cast<BlockArgument>(value)) {
assert(blockArgument.getOwner() != &oldBlock && "expected block argument to be mapped before cloning");
assert(false && "unexpected captured block argument while scalarizing pim.core_batch");
}
Operation* definingOp = value.getDefiningOp();
assert(definingOp && "expected captured value to be defined by an operation");
assert(definingOp->getBlock() != &oldBlock && "expected in-block value to be mapped before cloning");
for (Value operand : definingOp->getOperands())
(void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper);
Operation* cloned = builder.clone(*definingOp, mapper);
for (auto [originalResult, clonedResult] : llvm::zip(definingOp->getResults(), cloned->getResults()))
mapper.map(originalResult, clonedResult);
return mapper.lookup(value);
}
static void cloneScalarizedLaneBody(OpBuilder& builder,
pim::PimCoreBatchOp coreBatchOp,
unsigned lane,
OperationFolder& constantFolder) {
Block& oldBlock = coreBatchOp.getBody().front();
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
size_t laneCount = static_cast<size_t>(coreBatchOp.getLaneCount());
size_t weightCount = coreBatchOp.getWeights().size();
IRMapping mapper;
for (auto [argIndex, blockArg] : llvm::enumerate(oldBlock.getArguments())) {
if (blockArg.getType().isIndex()) {
mapper.map(blockArg, getOrCreateHostIndexConstant(coreBatchOp, static_cast<int64_t>(lane), constantFolder));
mapper.map(blockArg, getOrCreateHostIndexConstant(anchorOp, static_cast<int64_t>(lane), constantFolder));
continue;
}
if (argIndex <= weightCount) {
mapper.map(blockArg, coreBatchOp.getWeights()[argIndex - 1]);
auto scalarCoreOp = cast<pim::PimCoreOp>(anchorOp);
mapper.map(blockArg, scalarCoreOp.getWeightArgument(argIndex - 1));
continue;
}
@@ -57,8 +81,10 @@ static void cloneScalarizedLaneBody(OpBuilder& builder,
if (isa<pim::PimHaltOp>(op))
continue;
for (Value operand : op.getOperands())
(void) getOrCloneCapturedValue(builder, oldBlock, operand, mapper);
if (auto sendBatchOp = dyn_cast<pim::PimSendBatchOp>(op)) {
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
pim::PimSendOp::create(
builder,
sendBatchOp.getLoc(),
@@ -78,7 +104,6 @@ static void cloneScalarizedLaneBody(OpBuilder& builder,
}
if (auto receiveBatchOp = dyn_cast<pim::PimReceiveBatchOp>(op)) {
Operation* anchorOp = builder.getInsertionBlock()->getParentOp();
auto scalarReceive = pim::PimReceiveOp::create(
builder,
receiveBatchOp.getLoc(),
@@ -106,8 +131,8 @@ static void cloneScalarizedLaneBody(OpBuilder& builder,
builder,
memcpBatchOp.getLoc(),
memcpBatchOp.getOutput().getType(),
getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder),
getOrCreateHostIndexConstant(coreBatchOp, memcpBatchOp.getHostSourceOffset(), constantFolder),
getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getDeviceTargetOffset(), constantFolder),
getOrCreateHostIndexConstant(anchorOp, memcpBatchOp.getHostSourceOffset(), constantFolder),
mapper.lookup(memcpBatchOp.getDeviceTarget()),
mapper.lookup(memcpBatchOp.getHostSource()),
memcpBatchOp.getSizeAttr());
@@ -141,7 +166,16 @@ LogicalResult withScalarCoreFromBatchLanes(pim::PimCoreBatchOp coreBatchOp,
auto scalarCore =
pim::PimCoreOp::create(builder, coreBatchOp.getLoc(), ValueRange(weights), builder.getI32IntegerAttr(coreId));
Block* block = builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end());
SmallVector<Type> weightTypes;
SmallVector<Location> weightLocs;
weightTypes.reserve(weights.size());
weightLocs.reserve(weights.size());
for (Value weight : weights) {
weightTypes.push_back(weight.getType());
weightLocs.push_back(weight.getLoc());
}
Block* block =
builder.createBlock(&scalarCore.getBody(), scalarCore.getBody().end(), TypeRange(weightTypes), weightLocs);
builder.setInsertionPointToEnd(block);
for (unsigned lane : lanes)
cloneScalarizedLaneBody(builder, coreBatchOp, lane, constantFolder);