add support for softmax, resize, split, gather
Some checks failed
Validate Operations / validate-operations (push) Has been cancelled

This commit is contained in:
NiccoloN
2026-04-09 14:25:00 +02:00
parent 86916a8fa0
commit 1a0192d1f9
16 changed files with 560 additions and 8 deletions

View File

@@ -455,4 +455,27 @@ def PimVSigmOp : PimOp<"vsigm", [DestinationStyleOpInterface]> {
}];
}
// Destination-passing-style softmax over the whole input vector: the result
// is written into a caller-provided buffer ($outputBuffer) and also returned
// as an SSA value ($output), matching the sibling unary ops (vrelu/vtanh/vsigm).
def PimVSoftmaxOp : PimOp<"vsoftmax", [DestinationStyleOpInterface]> {
let summary = "Softmax over the full input vector";
// $input: vector to normalize; $outputBuffer: destination tensor for the result.
let arguments = (ins
PimTensor:$input,
PimTensor:$outputBuffer
);
let results = (outs
PimTensor:$output
);
// DestinationStyleOpInterface hook: declares $outputBuffer as the single DPS
// init operand, so one-shot bufferization can write the result in place.
let extraClassDeclaration = [{
mlir::MutableOperandRange getDpsInitsMutable() {
return getOutputBufferMutable();
}
}];
// Declarative syntax, e.g.:
//   pim.vsoftmax(%in, %buf) : (tensor<...>, tensor<...>) -> tensor<...>
let assemblyFormat = [{
`(` $input `,` $outputBuffer `)` attr-dict `:` `(` type($input) `,` type($outputBuffer) `)` `->` type($output)
}];
}
#endif // PIM_DIALECT_H

View File

@@ -273,6 +273,7 @@ void registerOpBufferizationInterfaces(DialectRegistry& registry) {
PimVReluOp::attachInterface<UnaryDstOpInterface<PimVReluOp>>(*ctx);
PimVTanhOp::attachInterface<UnaryDstOpInterface<PimVTanhOp>>(*ctx);
PimVSigmOp::attachInterface<UnaryDstOpInterface<PimVSigmOp>>(*ctx);
PimVSoftmaxOp::attachInterface<UnaryDstOpInterface<PimVSoftmaxOp>>(*ctx);
});
}

View File

@@ -272,6 +272,22 @@ def SpatSigmoidOp : SpatOp<"sigmoid", []> {
}];
}
// Value-semantics softmax in the Spat dialect: takes one tensor slice and
// produces a new result tensor (no destination operand, unlike the Pim-level
// DPS ops).
def SpatSoftmaxOp : SpatOp<"softmax", []> {
let summary = "Softmax over the full input tensor slice";
// $input: the tensor slice to normalize.
let arguments = (ins
SpatTensor:$input
);
let results = (outs
SpatTensor:$output
);
// Declarative syntax, e.g.:
//   spat.softmax(%in) : tensor<...> -> tensor<...>
let assemblyFormat = [{
`(` $input `)` attr-dict `:` type($input) `->` type($output)
}];
}
def SpatReluOp : SpatOp<"relu", []> {
let summary = "Element-wise ReLU activation";

View File

@@ -202,9 +202,9 @@ private:
rewriter.clone(op, mapper);
}
for (auto users : oldWeightedCompute->getUsers())
if (auto funcRet = dyn_cast<func::ReturnOp>(users))
funcRet.setOperand(0, newWeightedCompute.getResult(0));
for (auto& use : llvm::make_early_inc_range(oldWeightedCompute->getUses()))
if (isa<func::ReturnOp>(use.getOwner()))
use.assign(newWeightedCompute.getResult(0));
oldToNewComputeMap.insert({oldWeightedCompute, newWeightedCompute});
return {cast<SpatWeightedCompute>(newWeightedCompute), computeValueResults};