Commit: aa5206908510552b6daeed83b2a5954014943133
Parent: fde09cde680e00eeea6ced6a452d9d16a9326209
Author: Randy Palamar
Date: Sat, 25 Oct 2025 21:37:31 -0600
shaders/filter: use LDS to cooperatively load needed samples
the iq calculation is now also performed once per sample, when the
sample is loaded, instead of each time a thread needs it.
this gives a ~28% speedup in a chirp filtering test.
Diffstat:
4 files changed, 48 insertions(+), 32 deletions(-)
diff --git a/beamformer.c b/beamformer.c
@@ -23,9 +23,9 @@
global f32 dt_for_frame;
-#define FILTER_LOCAL_SIZE_X 64
-#define FILTER_LOCAL_SIZE_Y 1
-#define FILTER_LOCAL_SIZE_Z 1
+#define FILTER_LOCAL_SIZE_X 128
+#define FILTER_LOCAL_SIZE_Y 1
+#define FILTER_LOCAL_SIZE_Z 1
#define DECODE_LOCAL_SIZE_X 4
#define DECODE_LOCAL_SIZE_Y 1
@@ -573,6 +573,7 @@ plan_compute_pipeline(BeamformerComputePlan *cp, BeamformerParameterBlock *pb)
bp->Filter.demodulation_frequency = pb->parameters.demodulation_frequency;
bp->Filter.sampling_frequency = pb->parameters.sampling_frequency / 2;
bp->Filter.decimation_rate = decimation_rate;
+ bp->Filter.sample_count = pb->parameters.sample_count;
if (first) {
bp->Filter.input_channel_stride = pb->parameters.raw_data_dimensions[0] / 2;
@@ -601,6 +602,7 @@ plan_compute_pipeline(BeamformerComputePlan *cp, BeamformerParameterBlock *pb)
bp->Filter.input_channel_stride = sample_count * pb->parameters.acquisition_count;
bp->Filter.input_sample_stride = 1;
bp->Filter.input_transmit_stride = sample_count;
+ bp->Filter.sample_count = sample_count;
}
/* TODO(rnp): filter may need a different dispatch layout */
diff --git a/beamformer.meta b/beamformer.meta
@@ -90,6 +90,7 @@
@BakeInt(OutputChannelStride output_channel_stride )
@BakeInt(OutputSampleStride output_sample_stride )
@BakeInt(OutputTransmitStride output_transmit_stride)
+ @BakeInt(SampleCount sample_count )
@BakeInt(SamplingMode sampling_mode )
@BakeFloat(DemodulationFrequency demodulation_frequency)
@BakeFloat(SamplingFrequency sampling_frequency )
diff --git a/generated/beamformer.meta.c b/generated/beamformer.meta.c
@@ -110,6 +110,7 @@ typedef struct {
u32 output_channel_stride;
u32 output_sample_stride;
u32 output_transmit_stride;
+ u32 sample_count;
u32 sampling_mode;
f32 demodulation_frequency;
f32 sampling_frequency;
@@ -304,6 +305,7 @@ read_only global s8 *beamformer_shader_bake_parameter_names[] = {
s8_comp("OutputChannelStride"),
s8_comp("OutputSampleStride"),
s8_comp("OutputTransmitStride"),
+ s8_comp("SampleCount"),
s8_comp("SamplingMode"),
s8_comp("DemodulationFrequency"),
s8_comp("SamplingFrequency"),
@@ -330,7 +332,7 @@ read_only global s8 *beamformer_shader_bake_parameter_names[] = {
read_only global u8 *beamformer_shader_bake_parameter_is_float[] = {
(u8 []){0, 0, 0, 0, 0, 0, 0, 0},
- (u8 []){0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1},
+ (u8 []){0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1},
(u8 []){0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1},
0,
0,
@@ -339,7 +341,7 @@ read_only global u8 *beamformer_shader_bake_parameter_is_float[] = {
read_only global i32 beamformer_shader_bake_parameter_counts[] = {
8,
- 11,
+ 12,
13,
0,
0,
diff --git a/shaders/filter.glsl b/shaders/filter.glsl
@@ -1,4 +1,6 @@
/* See LICENSE for license details. */
+/* TODO(rnp): bug: this won't filter RF data correctly */
+#define SAMPLE_TYPE vec2
#if DataKind == DataKind_Float32
#define DATA_TYPE vec2
#define RESULT_TYPE_CAST(v) (v)
@@ -26,7 +28,7 @@ layout(std430, binding = 2) writeonly restrict buffer buffer_2 {
};
layout(std430, binding = 3) readonly restrict buffer buffer_3 {
- FILTER_TYPE filter_coefficients[];
+ // sized explicitly: main() reads entries FilterLength - 1 down to 0
+ FILTER_TYPE filter_coefficients[FilterLength];
};
layout(r16i, binding = 1) readonly restrict uniform iimage1D channel_mapping;
@@ -69,15 +71,16 @@ vec2 rotate_iq(vec2 iq, int index)
}
#endif
-vec2 sample_rf(uint index)
+/* Returns in_data[index] converted with SAMPLE_TYPE_CAST. */
+SAMPLE_TYPE sample_rf(uint index)
{
- vec2 result = SAMPLE_TYPE_CAST(in_data[index]);
+ SAMPLE_TYPE result = SAMPLE_TYPE_CAST(in_data[index]);
return result;
}
+/* NOTE: per-workgroup sample cache. A workgroup produces gl_WorkGroupSize.x
+ * outputs spaced DecimationRate input samples apart and each output reads
+ * FilterLength consecutive samples, so the cache must hold
+ * FilterLength + DecimationRate * gl_WorkGroupSize.x entries. */
+shared SAMPLE_TYPE local_samples[FilterLength + DecimationRate * gl_WorkGroupSize.x];
+
void main()
{
- uint in_sample = gl_GlobalInvocationID.x * DecimationRate;
uint out_sample = gl_GlobalInvocationID.x;
uint channel = gl_GlobalInvocationID.y;
uint transmit = gl_GlobalInvocationID.z;
@@ -88,32 +91,40 @@ void main()
OutputTransmitStride * transmit +
OutputSampleStride * out_sample;
- int target;
- if (bool(MapChannels)) {
- target = OutputChannelStride / OutputSampleStride;
- } else {
- target = OutputTransmitStride;
- }
-
- if (out_sample < target) {
- target *= DecimationRate;
-
- vec2 result = vec2(0);
- int a_length = target;
- int index = int(in_sample);
+ int thread_index = int(gl_LocalInvocationIndex);
+ int thread_count = int(gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z);
+ /////////////////////////
+ // NOTE: sample caching
+ // The workgroup cooperatively fills local_samples with the input samples
+ // in [min_sample, max_sample). Masking and (when Demodulate is set)
+ // scaling and rotation are applied here, once per sample.
+ {
+ // the window starts FilterLength samples before the workgroup's first output position
+ int min_sample = DecimationRate * int((gl_WorkGroupID.x + 0) * gl_WorkGroupSize.x) - FilterLength;
+ int max_sample = DecimationRate * int((gl_WorkGroupID.x + 1) * gl_WorkGroupSize.x);
+
+ in_offset += min_sample;
+ int total_samples = max_sample - min_sample;
+ // split the window into contiguous per-thread ranges; the first
+ // leftover_count threads each take one extra sample
+ int samples_per_thread = total_samples / thread_count;
+ int leftover_count = total_samples % thread_count;
+ int thread_first_index = samples_per_thread * thread_index + min(thread_index, leftover_count);
+ // exclusive end: equal to the next thread's thread_first_index
+ int thread_last_index = thread_first_index + samples_per_thread + int(thread_index < leftover_count);
const float scale = bool(ComplexFilter) ? 1 : sqrt(2);
-
- for (int j = max(0, index - FilterLength); j < min(index, a_length); j++) {
- vec2 iq = sample_rf(in_offset + j);
- FILTER_TYPE h = filter_coefficients[index - j];
- #if Demodulate
- result += scale * apply_filter(rotate_iq(iq * vec2(1, -1), -j), h);
- #else
- result += apply_filter(iq, h);
- #endif
+ // the bound is exclusive: with <= neighbouring threads would overlap and
+ // the last thread would store to local_samples[total_samples]
+ for (int i = thread_first_index; i < thread_last_index; i++) {
+ // samples whose index (i + min_sample) is negative are stored as zero
+ SAMPLE_TYPE valid = SAMPLE_TYPE(i + min_sample >= 0);
+ #if Demodulate
+ // rotate_iq() gets the absolute sample index, not the position in the
+ // cache, so the result does not depend on which workgroup loaded it
+ local_samples[i] = scale * valid * rotate_iq(sample_rf(in_offset + i) * vec2(1, -1), -(i + min_sample));
+ #else
+ local_samples[i] = valid * sample_rf(in_offset + i);
+ #endif
+ }
+ }
+ // every invocation must finish its stores before any invocation reads the cache
+ barrier();
+
+ // only invocations that map to one of the SampleCount / DecimationRate outputs store a result
+ if (out_sample < SampleCount / DecimationRate) {
+ SAMPLE_TYPE result = SAMPLE_TYPE(0);
+ // this invocation's FilterLength cached samples start at DecimationRate * thread_index
+ int offset = DecimationRate * thread_index;
+ for (int j = 0; j < FilterLength; j++) {
+ // coefficients are walked in reverse: the last cached sample of the
+ // window (j = FilterLength - 1) pairs with filter_coefficients[0]
+ result += apply_filter(local_samples[offset + j],
+ filter_coefficients[FilterLength - 1 - j]);
}
-
out_data[out_offset] = RESULT_TYPE_CAST(result);
}
}