[BACKEND] Add barrier after assert op to avoid race condition (#5035)
Add a barrier to avoid a race condition when an assert is followed by
an op that may trap if the assert condition is true. Since the tensors in
those two operations may have different layouts, we need to make sure all
threads have finished executing the assert before moving on to the next op.
ThomasRaoux authored Nov 1, 2024
1 parent f0a6d01 commit d0db12b
Showing 2 changed files with 7 additions and 0 deletions.
5 changes: 5 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp
@@ -35,6 +35,11 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
}
}
llAssert(op, condition, adaptor.getMessage(), rewriter);
// Add a barrier to avoid a race condition in case an assert is followed by
// an op that may trap if the assert condition is true. Since the tensor in
// those two operations may have different layout we need to make sure all
// the threads are done executing the assert before going to the next op.
barrier();
rewriter.eraseOp(op);
return success();
}
2 changes: 2 additions & 0 deletions test/Conversion/tritongpu_to_llvm.mlir
@@ -1906,6 +1906,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// CHECK-DAG: llvm.mlir.global internal constant @assertFunc_0("unknown\00") {addr_space = 0 : i32}
// CHECK-DAG: llvm.mlir.global internal constant @assertFile_0("inner_call\00") {addr_space = 0 : i32}
// CHECK-DAG: llvm.mlir.global internal constant @assertMessage_0("assert text\00") {addr_space = 0 : i32}
// CHECK: llvm.call @__assertfail
// CHECK: nvvm.barrier0
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
tt.func public @add_kernel(%arg0: tensor<1xi1, #blocked>) {
tt.assert %arg0, "assert text" : tensor<1xi1, #blocked> loc(#loc5)
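
The ordering guarantee described in the commit message can be illustrated with a small CPU-side sketch. This is not Triton or GPU code: ordinary `std::thread`s stand in for the threads of a block, and `BlockBarrier`, `Event`, `runBlock`, and `assertsPrecedeNextOp` are hypothetical names introduced only for this illustration. Each thread checks "its" element under one layout, waits at the barrier, and only then touches a different element under a second layout, so no thread can reach the next op before every assert check has run.

```cpp
#include <cassert>
#include <condition_variable>
#include <cstddef>
#include <mutex>
#include <thread>
#include <vector>

// Hypothetical one-shot barrier standing in for a block-wide GPU barrier.
class BlockBarrier {
public:
  explicit BlockBarrier(std::size_t numThreads) : remaining_(numThreads) {}

  // Blocks until every participating thread has arrived.
  void arriveAndWait() {
    std::unique_lock<std::mutex> lock(mutex_);
    if (--remaining_ == 0) {
      cv_.notify_all();
      return;
    }
    cv_.wait(lock, [this] { return remaining_ == 0; });
  }

private:
  std::mutex mutex_;
  std::condition_variable cv_;
  std::size_t remaining_;
};

// One log entry: which thread handled which element, and in which phase.
struct Event {
  bool isAssert;
  std::size_t thread;
  std::size_t element;
};

// Runs numThreads threads through "assert, barrier, next op" and returns
// the order in which the events were recorded.
std::vector<Event> runBlock(std::size_t numThreads) {
  std::vector<Event> log;
  std::mutex logMutex;
  BlockBarrier barrier(numThreads);
  std::vector<std::thread> threads;

  for (std::size_t tid = 0; tid < numThreads; ++tid) {
    threads.emplace_back([&, tid] {
      {
        // Assert phase, layout A: thread tid checks element tid.
        std::lock_guard<std::mutex> guard(logMutex);
        log.push_back({true, tid, tid});
      }
      // Without this wait, a thread could start the next op while another
      // thread has not yet checked the element the next op depends on.
      barrier.arriveAndWait();
      {
        // Next op, layout B: thread tid touches element (tid + 1) % N.
        std::lock_guard<std::mutex> guard(logMutex);
        log.push_back({false, tid, (tid + 1) % numThreads});
      }
    });
  }
  for (auto &t : threads)
    t.join();
  return log;
}

// True if the first numThreads events are all assert checks and every
// later event belongs to the next op.
bool assertsPrecedeNextOp(const std::vector<Event> &log,
                          std::size_t numThreads) {
  for (std::size_t i = 0; i < log.size(); ++i) {
    if (log[i].isAssert != (i < numThreads))
      return false;
  }
  return true;
}
```

Calling `runBlock(4)` and passing the result to `assertsPrecedeNextOp(log, 4)` should always return true, because the barrier releases no thread until all four assert events are logged. Removing the `arriveAndWait()` call makes the interleaving depend on scheduling, which is the kind of race the barrier in the diff is meant to rule out.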
