PR20855 IR examples
// tools/test/iree-run-module-multi.mlir
func.func public @multi_device_mul(
// Input argument is resident on device_a (the tooling defaults to the first device).
%input_a: tensor<4xf32> {iree.abi.affinity = #hal.device.promise<@device_a>}
) -> (
// Output result is expected to be on device_a (though not required).
tensor<4xf32> {iree.abi.affinity = #hal.device.promise<@device_a>}
) {
// Compute on device_a (input is there).
%constant_a = arith.constant dense<[0.0, 1.0, 2.0, 3.0]> : tensor<4xf32>
%transient_a = arith.mulf %input_a, %constant_a : tensor<4xf32>
// Transfer the result from device_a -> device_b.
%transient_b = flow.tensor.transfer %transient_a : tensor<4xf32> to #hal.device.promise<@device_b>
// Compute on device_b.
%constant_b = arith.constant dense<[4.0, 5.0, 6.0, 7.0]> : tensor<4xf32>
%result_b = arith.mulf %transient_b, %constant_b : tensor<4xf32>
// Transfer the result from device_b -> device_a.
%result_a = flow.tensor.transfer %result_b : tensor<4xf32> to #hal.device.promise<@device_a>
// Return the result on device_a (as required by ABI attr).
func.return %result_a : tensor<4xf32>
}
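// Note: the #hal.device.promise<@device_a>/<@device_b> attrs above only name the
// devices; once targets are configured they are backed by module-level device
// globals. A minimal sketch of what those look like (copied from the lowered
// listing below, which uses two local/vmvx devices):
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
#device_target_local_1_ = #hal.device.target<"local", {ordinal = 1 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
util.global private @device_a = #device_target_local_0_
util.global private @device_b = #device_target_local_1_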
// We'll need to elide transfers before usage refinement.
// This is the current behavior today (as for a discrete CPU/GPU pair, where we do
// want the transfers); the two inserted transfers are excerpted after this listing.
#executable_target_vmvx_bytecode_fb = #hal.executable.target<"vmvx", "vmvx-bytecode-fb", {iree.encoding.resolver = #iree_cpu.vmvx_encoding_layout<>, ukernels = "none"}>
#map = affine_map<(d0) -> (d0)>
#device_target_local_0_ = #hal.device.target<"local", {ordinal = 0 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
#device_target_local_1_ = #hal.device.target<"local", {ordinal = 1 : index}, [#executable_target_vmvx_bytecode_fb]> : !hal.device
module attributes {stream.affinity.default = #hal.device.affinity<@device_a>} {
util.global private @device_a = #device_target_local_0_
util.global private @device_b = #device_target_local_1_
stream.executable private @multi_device_mul_dispatch_0 {
stream.executable.export public @multi_device_mul_dispatch_0_elementwise_4_f32 workgroups() -> (index, index, index) {
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice
stream.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @multi_device_mul_dispatch_0_elementwise_4_f32(%arg0: !stream.binding, %arg1: !stream.binding) {
%cst = arith.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
%c0 = arith.constant 0 : index
%0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !iree_tensor_ext.dispatch.tensor<readonly:tensor<4xf32>>
%1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<4xf32>>
%2 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0], sizes = [4], strides = [1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4xf32>> -> tensor<4xf32>
%3 = tensor.empty() : tensor<4xf32>
%4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%2, %cst : tensor<4xf32>, tensor<4xf32>) outs(%3 : tensor<4xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%5 = arith.mulf %in, %in_0 : f32
linalg.yield %5 : f32
} -> tensor<4xf32>
iree_tensor_ext.dispatch.tensor.store %4, %1, offsets = [0], sizes = [4], strides = [1] : tensor<4xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<4xf32>>
return
}
}
}
stream.executable private @multi_device_mul_dispatch_1 {
stream.executable.export public @multi_device_mul_dispatch_1_elementwise_4_f32 workgroups() -> (index, index, index) {
%x, %y, %z = iree_tensor_ext.dispatch.workgroup_count_from_slice
stream.return %x, %y, %z : index, index, index
}
builtin.module {
func.func @multi_device_mul_dispatch_1_elementwise_4_f32(%arg0: !stream.binding, %arg1: !stream.binding) {
%cst = arith.constant dense<[4.000000e+00, 5.000000e+00, 6.000000e+00, 7.000000e+00]> : tensor<4xf32>
%c0 = arith.constant 0 : index
%0 = stream.binding.subspan %arg0[%c0] : !stream.binding -> !iree_tensor_ext.dispatch.tensor<readonly:tensor<4xf32>>
%1 = stream.binding.subspan %arg1[%c0] : !stream.binding -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<4xf32>>
%2 = iree_tensor_ext.dispatch.tensor.load %0, offsets = [0], sizes = [4], strides = [1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<4xf32>> -> tensor<4xf32>
%3 = tensor.empty() : tensor<4xf32>
%4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%2, %cst : tensor<4xf32>, tensor<4xf32>) outs(%3 : tensor<4xf32>) {
^bb0(%in: f32, %in_0: f32, %out: f32):
%5 = arith.mulf %in, %in_0 : f32
linalg.yield %5 : f32
} -> tensor<4xf32>
iree_tensor_ext.dispatch.tensor.store %4, %1, offsets = [0], sizes = [4], strides = [1] : tensor<4xf32> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<4xf32>>
return
}
}
}
util.func public @multi_device_mul(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "async func @multi_device_mul(%input0: tensor<4xf32> {iree.abi.affinity = #hal.device.promise<@device_a>}) -> (%output0: tensor<4xf32> {iree.abi.affinity = #hal.device.promise<@device_a>})", iree.abi.model = "coarse-fences"}} {
%c4 = arith.constant 4 : index
%element_type_f32 = hal.element_type<f32> : i32
%dense_row_major = hal.encoding_type<dense_row_major> : i32
hal.buffer_view.assert<%arg0 : !hal.buffer_view> message("input0") shape([%c4]) type(%element_type_f32) encoding(%dense_row_major)
%0 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4xf32> : index
%1 = stream.tensor.import on(#hal.device.affinity<@device_a>) %arg0 : !hal.buffer_view -> tensor<4xf32> in !stream.resource<external>{%0}
%2 = stream.timepoint.import on(#hal.device.affinity<@device_a>) %arg1 : (!hal.fence) => !stream.timepoint
%3 = stream.timepoint.await %2 => %1 : !stream.resource<external>{%0}
%4 = stream.async.transfer %3 : !stream.resource<external>{%0} from(#hal.device.affinity<@device_a>) -> to(#hal.device.affinity<@device_a>) !stream.resource<*>{%0}
%5 = stream.tensor.dispatch on(#hal.device.affinity<@device_a>) @multi_device_mul_dispatch_0::@multi_device_mul_dispatch_0_elementwise_4_f32(%4) : (tensor<4xf32> in !stream.resource<*>{%0}) -> tensor<4xf32> in !stream.resource<*>{%0}
%6 = stream.async.transfer %5 : !stream.resource<*>{%0} from(#hal.device.affinity<@device_a>) -> to(#hal.device.affinity<@device_b>) !stream.resource<*>{%0}
%7 = stream.tensor.sizeof on(#hal.device.affinity<@device_b>) tensor<4xf32> : index
%8 = stream.tensor.dispatch on(#hal.device.affinity<@device_b>) @multi_device_mul_dispatch_1::@multi_device_mul_dispatch_1_elementwise_4_f32(%6) : (tensor<4xf32> in !stream.resource<*>{%0}) -> tensor<4xf32> in !stream.resource<*>{%7}
%9 = stream.async.transfer %8 : !stream.resource<*>{%7} from(#hal.device.affinity<@device_b>) -> to(#hal.device.affinity<@device_a>) !stream.resource<*>{%7}
%result, %result_timepoint = stream.timepoint.barrier on(#hal.device.affinity<@device_a>) %9 : !stream.resource<*>{%7} => !stream.timepoint
stream.timepoint.chain_external on(#hal.device.affinity<@device_a>) %result_timepoint => (%arg2 : !hal.fence)
%10 = stream.async.transfer %result : !stream.resource<*>{%7} from(#hal.device.affinity<@device_a>) -> to(#hal.device.affinity<@device_a>) !stream.resource<external>{%7}
%11 = stream.tensor.export on(#hal.device.affinity<@device_a>) %10 : tensor<4xf32> in !stream.resource<external>{%7} -> !hal.buffer_view
util.return %11 : !hal.buffer_view
}
}
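// The cross-device round trip this phase preserves is carried by the two
// stream.async.transfer ops above; distilled here as a sketch (shortened names,
// same !stream.resource<*> lifetime on both sides):
%to_b = stream.async.transfer %on_a : !stream.resource<*>{%size} from(#hal.device.affinity<@device_a>) -> to(#hal.device.affinity<@device_b>) !stream.resource<*>{%size}
%to_a = stream.async.transfer %on_b : !stream.resource<*>{%size} from(#hal.device.affinity<@device_b>) -> to(#hal.device.affinity<@device_a>) !stream.resource<*>{%size}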
// Once allocated, the transient resource that is used on both devices carries an
// optimal placement attr, while resources used only on a single device do not
// (contrasted in the excerpt after this listing).
util.func public @multi_device_mul(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "async func @multi_device_mul(%input0: tensor<4xf32> {iree.abi.affinity = #hal.device.promise<@device_a>}) -> (%output0: tensor<4xf32> {iree.abi.affinity = #hal.device.promise<@device_a>})", iree.abi.model = "coarse-fences"}} {
%c64 = arith.constant 64 : index
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%c4 = arith.constant 4 : index
%element_type_f32 = hal.element_type<f32> : i32
%dense_row_major = hal.encoding_type<dense_row_major> : i32
hal.buffer_view.assert<%arg0 : !hal.buffer_view> message("input0") shape([%c4]) type(%element_type_f32) encoding(%dense_row_major)
%0 = stream.tensor.import on(#hal.device.affinity<@device_a>) %arg0 : !hal.buffer_view -> tensor<4xf32> in !stream.resource<external>{%c16}
%1 = stream.timepoint.import on(#hal.device.affinity<@device_a>) %arg1 : (!hal.fence) => !stream.timepoint
%result, %result_timepoint = stream.resource.alloca uninitialized on(#hal.device.optimal<[#hal.device.affinity<@device_a>, #hal.device.affinity<@device_b>]>) await(%1) => !stream.resource<transient>{%c16} => !stream.timepoint
%result_0, %result_timepoint_1 = stream.resource.alloca uninitialized on(#hal.device.affinity<@device_a>) await(%1) => !stream.resource<transient>{%c64} => !stream.timepoint
%2 = stream.timepoint.join max(%1, %result_timepoint, %result_timepoint_1) => !stream.timepoint
%3 = stream.cmd.execute on(#hal.device.affinity<@device_a>) await(%2) => with(%0 as %arg3: !stream.resource<external>{%c16}, %result as %arg4: !stream.resource<transient>{%c16}, %result_0 as %arg5: !stream.resource<transient>{%c64}) {
stream.cmd.dispatch @multi_device_mul_dispatch_0::@multi_device_mul_dispatch_0_elementwise_4_f32 {
ro %arg3[%c0 for %c16] : !stream.resource<external>{%c16},
wo %arg5[%c0 for %c16] : !stream.resource<transient>{%c64}
}
stream.cmd.copy %arg5[%c0], %arg4[%c0], %c16 : !stream.resource<transient>{%c64} -> !stream.resource<transient>{%c16}
stream.cmd.flush to(#hal.device.affinity<@device_b>) %arg4[%c0 for %c16] : !stream.resource<transient>{%c16}
} => !stream.timepoint
%4 = stream.resource.dealloca on(#hal.device.affinity<@device_a>) await(%3) => %result_0 : !stream.resource<transient>{%c64} => !stream.timepoint
%5 = stream.timepoint.join max(%4, %3) => !stream.timepoint
%result_2, %result_timepoint_3 = stream.resource.alloca uninitialized on(#hal.device.affinity<@device_a>) await(%5) => !stream.resource<external>{%c16} => !stream.timepoint
%result_4, %result_timepoint_5 = stream.resource.alloca uninitialized on(#hal.device.affinity<@device_b>) await(%5) => !stream.resource<transient>{%c64} => !stream.timepoint
%6 = stream.timepoint.join max(%4, %3, %result_timepoint_3, %result_timepoint_5) => !stream.timepoint
%7 = stream.cmd.execute on(#hal.device.affinity<@device_b>) await(%6) => with(%result as %arg3: !stream.resource<transient>{%c16}, %result_2 as %arg4: !stream.resource<external>{%c16}, %result_4 as %arg5: !stream.resource<transient>{%c64}) {
stream.cmd.dispatch @multi_device_mul_dispatch_1::@multi_device_mul_dispatch_1_elementwise_4_f32 {
ro %arg3[%c0 for %c16] : !stream.resource<transient>{%c16},
wo %arg5[%c0 for %c16] : !stream.resource<transient>{%c64}
}
stream.cmd.copy %arg5[%c0], %arg4[%c0], %c16 : !stream.resource<transient>{%c64} -> !stream.resource<external>{%c16}
stream.cmd.flush to(#hal.device.affinity<@device_a>) %arg4[%c0 for %c16] : !stream.resource<external>{%c16}
} => !stream.timepoint
%8 = stream.resource.dealloca on(#hal.device.optimal<[#hal.device.affinity<@device_a>, #hal.device.affinity<@device_b>]>) await(%7) => %result : !stream.resource<transient>{%c16} => !stream.timepoint
%9 = stream.resource.dealloca on(#hal.device.affinity<@device_b>) await(%8) => %result_4 : !stream.resource<transient>{%c64} => !stream.timepoint
%10 = stream.timepoint.join max(%9, %8) => !stream.timepoint
stream.timepoint.chain_external on(#hal.device.affinity<@device_a>) %10 => (%arg2 : !hal.fence)
%11 = stream.tensor.export on(#hal.device.affinity<@device_a>) %result_2 : tensor<4xf32> in !stream.resource<external>{%c16} -> !hal.buffer_view
util.return %11 : !hal.buffer_view
}
}
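// Distilled from the listing above (a sketch, not verbatim IR): the transient
// shared by both devices is allocated with the optimal placement attr while a
// resource used only on one device keeps a plain affinity:
%shared, %shared_tp = stream.resource.alloca uninitialized on(#hal.device.optimal<[#hal.device.affinity<@device_a>, #hal.device.affinity<@device_b>]>) await(%tp) => !stream.resource<transient>{%c16} => !stream.timepoint
%local, %local_tp = stream.resource.alloca uninitialized on(#hal.device.affinity<@device_a>) await(%tp) => !stream.resource<transient>{%c64} => !stream.timepoint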
// Looking forward to the transfer elision pass: this removes the device_a->device_b
// transfers (while preserving the lifetime changes, which we need). The pass would
// likely just change the affinities on the transfers instead of removing the ops
// and let the canonicalizer clean them up, as sketched after this listing.
util.func public @multi_device_mul(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "async func @multi_device_mul(%input0: tensor<4xf32> {iree.abi.affinity = #hal.device.promise<@device_a>}) -> (%output0: tensor<4xf32> {iree.abi.affinity = #hal.device.promise<@device_a>})", iree.abi.model = "coarse-fences"}} {
%c4 = arith.constant 4 : index
%element_type_f32 = hal.element_type<f32> : i32
%dense_row_major = hal.encoding_type<dense_row_major> : i32
hal.buffer_view.assert<%arg0 : !hal.buffer_view> message("input0") shape([%c4]) type(%element_type_f32) encoding(%dense_row_major)
%0 = stream.tensor.sizeof on(#hal.device.affinity<@device_a>) tensor<4xf32> : index
%1 = stream.tensor.import on(#hal.device.affinity<@device_a>) %arg0 : !hal.buffer_view -> tensor<4xf32> in !stream.resource<external>{%0}
%2 = stream.timepoint.import on(#hal.device.affinity<@device_a>) %arg1 : (!hal.fence) => !stream.timepoint
%3 = stream.timepoint.await %2 => %1 : !stream.resource<external>{%0}
%4 = stream.async.transfer %3 : !stream.resource<external>{%0} from(#hal.device.affinity<@device_a>) -> to(#hal.device.affinity<@device_a>) !stream.resource<*>{%0}
%5 = stream.tensor.dispatch on(#hal.device.affinity<@device_a>) @multi_device_mul_dispatch_0::@multi_device_mul_dispatch_0_elementwise_4_f32(%4) : (tensor<4xf32> in !stream.resource<*>{%0}) -> tensor<4xf32> in !stream.resource<*>{%0}
%7 = stream.tensor.sizeof on(#hal.device.affinity<@device_b>) tensor<4xf32> : index
%8 = stream.tensor.dispatch on(#hal.device.affinity<@device_b>) @multi_device_mul_dispatch_1::@multi_device_mul_dispatch_1_elementwise_4_f32(%5) : (tensor<4xf32> in !stream.resource<*>{%0}) -> tensor<4xf32> in !stream.resource<*>{%7}
%result, %result_timepoint = stream.timepoint.barrier on(#hal.device.affinity<@device_a>) %8 : !stream.resource<*>{%7} => !stream.timepoint
stream.timepoint.chain_external on(#hal.device.affinity<@device_a>) %result_timepoint => (%arg2 : !hal.fence)
%10 = stream.async.transfer %result : !stream.resource<*>{%7} from(#hal.device.affinity<@device_a>) -> to(#hal.device.affinity<@device_a>) !stream.resource<external>{%7}
%11 = stream.tensor.export on(#hal.device.affinity<@device_a>) %10 : tensor<4xf32> in !stream.resource<external>{%7} -> !hal.buffer_view
util.return %11 : !hal.buffer_view
}
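// Sketch of what the transfer elision pass would likely do to the unelided IR
// further above: rather than erasing ops it retargets a transfer's destination
// affinity so source and destination match (values copied from that listing).
// Before:
%6 = stream.async.transfer %5 : !stream.resource<*>{%0} from(#hal.device.affinity<@device_a>) -> to(#hal.device.affinity<@device_b>) !stream.resource<*>{%0}
// After the affinity change (now a no-op transfer with matching affinity and
// lifetime, which the canonicalizer is then expected to fold away, producing the
// listing above):
%6 = stream.async.transfer %5 : !stream.resource<*>{%0} from(#hal.device.affinity<@device_a>) -> to(#hal.device.affinity<@device_a>) !stream.resource<*>{%0}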
// Allocated without transfers: the allocation is placed optimally for the two
// devices and then used on both without copies (unlike the two copies required
// above; see the contrast after this listing).
util.func public @multi_device_mul(%arg0: !hal.buffer_view, %arg1: !hal.fence, %arg2: !hal.fence) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = "async func @multi_device_mul(%input0: tensor<4xf32> {iree.abi.affinity = #hal.device.promise<@device_a>}) -> (%output0: tensor<4xf32> {iree.abi.affinity = #hal.device.promise<@device_a>})", iree.abi.model = "coarse-fences"}} {
%c0_i32 = arith.constant 0 : i32
%c0 = arith.constant 0 : index
%c16 = arith.constant 16 : index
%c4 = arith.constant 4 : index
%element_type_f32 = hal.element_type<f32> : i32
%dense_row_major = hal.encoding_type<dense_row_major> : i32
hal.buffer_view.assert<%arg0 : !hal.buffer_view> message("input0") shape([%c4]) type(%element_type_f32) encoding(%dense_row_major)
%0 = stream.tensor.import on(#hal.device.affinity<@device_a>) %arg0 : !hal.buffer_view -> tensor<4xf32> in !stream.resource<external>{%c16}
%1 = stream.timepoint.import on(#hal.device.affinity<@device_a>) %arg1 : (!hal.fence) => !stream.timepoint
%result, %result_timepoint = stream.resource.alloca uninitialized on(#hal.device.optimal<[#hal.device.affinity<@device_a>, #hal.device.affinity<@device_b>]>) await(%1) => !stream.resource<transient>{%c16} => !stream.timepoint
%2 = stream.cmd.execute on(#hal.device.affinity<@device_a>) await(%result_timepoint) => with(%0 as %arg3: !stream.resource<external>{%c16}, %result as %arg4: !stream.resource<transient>{%c16}) {
stream.cmd.dispatch @multi_device_mul_dispatch_0::@multi_device_mul_dispatch_0_elementwise_4_f32 {
ro %arg3[%c0 for %c16] : !stream.resource<external>{%c16},
wo %arg4[%c0 for %c16] : !stream.resource<transient>{%c16}
}
} => !stream.timepoint
%result_0, %result_timepoint_1 = stream.resource.alloca uninitialized on(#hal.device.optimal<[#hal.device.affinity<@device_a>, #hal.device.affinity<@device_b>]>) await(%2) => !stream.resource<external>{%c16} => !stream.timepoint
%3 = stream.cmd.execute on(#hal.device.affinity<@device_b>) await(%result_timepoint_1) => with(%result as %arg3: !stream.resource<transient>{%c16}, %result_0 as %arg4: !stream.resource<external>{%c16}) {
stream.cmd.dispatch @multi_device_mul_dispatch_1::@multi_device_mul_dispatch_1_elementwise_4_f32 {
ro %arg3[%c0 for %c16] : !stream.resource<transient>{%c16},
wo %arg4[%c0 for %c16] : !stream.resource<external>{%c16}
}
} => !stream.timepoint
%4 = stream.resource.dealloca on(#hal.device.optimal<[#hal.device.affinity<@device_a>, #hal.device.affinity<@device_b>]>) await(%3) => %result : !stream.resource<transient>{%c16} => !stream.timepoint
stream.timepoint.chain_external on(#hal.device.affinity<@device_a>) %4 => (%arg2 : !hal.fence)
%5 = stream.tensor.export on(#hal.device.affinity<@device_a>) %result_0 : tensor<4xf32> in !stream.resource<external>{%c16} -> !hal.buffer_view
util.return %5 : !hal.buffer_view
}
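// Distilled contrast with the with-transfers listing further above (a sketch using
// ops copied from that listing): each cross-device handoff there required a
// staging copy plus a flush inside the producing device's execute region:
stream.cmd.copy %arg5[%c0], %arg4[%c0], %c16 : !stream.resource<transient>{%c64} -> !stream.resource<transient>{%c16}
stream.cmd.flush to(#hal.device.affinity<@device_b>) %arg4[%c0 for %c16] : !stream.resource<transient>{%c16}
// Here both stream.cmd.execute regions bind the same optimally placed resources
// directly and no copies or flushes are recorded.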