@vurtun
Last active April 7, 2025 13:17
animation sampling
#include <immintrin.h> // AVX, includes SSE2-SSE4.2
#include <assert.h>
#include <math.h>
#include <stdint.h>
#include <pthread.h>
#include <stdlib.h>
#include <stdio.h>
// Constants
#define MAX_MDLS 256
#define MAX_ANIMS 256
#define BATCH_SIZE 4
#define MAX_JOINTS 256
#define ANIM_TRK_ELM_CNT 7 // x, y, z, scl, rot_x, rot_y, rot_z
#define align32 __attribute__((aligned(32))) // GCC/Clang; use __declspec(align(32)) on MSVC
// Animation track types
enum anm_trks {
ANM_TRK_POS, // 3 components (x, y, z)
ANM_TRK_ROT, // 1 quaternion (32-bit encoded)
ANM_TRK_SCL, // 1 component (uniform scale)
ANM_TRK_CNT
};
struct anm {
int frame_cnt; // Total number of frames
float scl[ANM_TRK_CNT]; // Scaling factors for pos, rot, scl
struct {
uint16_t* pos; // 6 shorts: x, y, z, tx, ty, tz
uint32_t* rot; // 8 words per key: encoded quaternion at word 0, encoded tangent at word 4
uint16_t* scl; // 2 shorts: s, ts
} keys;
struct {
align32 int* key_off[ANM_TRK_CNT]; // Per-keyframe, per-joint offsets into keys
int* frame_to_key[ANM_TRK_CNT]; // Maps frame index to keyframe index (sparse keyframes)
int key_cnt[ANM_TRK_CNT]; // Number of keyframes per track
} blks;
};
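/* Scalar reference lookup, for illustration: resolves one joint's position
 * key through the two-level indirection above, assuming the layout implied by
 * the decode loop below (key_off indexed by key * MAX_JOINTS + joint, yielding
 * an element index into keys.pos with 6 shorts per entry). */
static inline const uint16_t*
anm_pos_key(const struct anm* a, int frame, int joint) {
int key = a->blks.frame_to_key[ANM_TRK_POS][frame]; // frame -> sparse keyframe
int off = a->blks.key_off[ANM_TRK_POS][key * MAX_JOINTS + joint]; // keyframe + joint -> key slot
return &a->keys.pos[off * 6]; // x, y, z, tx, ty, tz
}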
struct mdl_anm {
int anim;
int frame;
float exact_frame;
};
struct mdl {
int x;
};
struct mdl mdls[MAX_MDLS];
struct anm anms[MAX_ANIMS];
struct mdl_anm mdl_anms[MAX_MDLS];
// SOA structure for decoded data, fixed 256 joints
struct anm_frame {
float pos[2][3][MAX_JOINTS]; // [frame: 0/1][x,y,z][joint]
float rot[2][3][MAX_JOINTS]; // [frame: 0/1][x,y,z][joint] (no w)
float scl[2][MAX_JOINTS]; // [frame: 0/1][joint]
float tan_pos[2][3][MAX_JOINTS]; // Tangents for position
float tan_rot[2][3][MAX_JOINTS]; // Tangents for rotation (no tw)
float tan_scl[2][MAX_JOINTS]; // Tangents for scale
};
static inline void
qdec_avx(__m256* qout_x, __m256* qout_y, __m256* qout_z, __m256* qout_w, const uint32_t* qin) {
const __m256 half = _mm256_set1_ps(0.707106781f); // 1/sqrt(2): max magnitude of a non-largest component
const __m256 inv_msk = _mm256_set1_ps(1.0f / 511.0f); // 9-bit magnitude -> [0, 1]
const __m256 one = _mm256_set1_ps(1.0f);
const __m256 zero = _mm256_setzero_ps();
__m256i q = _mm256_load_si256((const __m256i*)qin);
__m256i top = _mm256_srli_epi32(q, 30);
__m256i mask = _mm256_set1_epi32(511);
__m256 qres[4] = {zero, zero, zero, zero};
__m256 sqr = zero;
// Component 0
__m256i mag0 = _mm256_and_si256(q, mask);
__m256i negbit0 = _mm256_and_si256(_mm256_srli_epi32(q, 9), _mm256_set1_epi32(1));
q = _mm256_srli_epi32(q, 10);
__m256 val0 = _mm256_mul_ps(_mm256_cvtepi32_ps(mag0), inv_msk);
val0 = _mm256_mul_ps(val0, half);
sqr = _mm256_fmadd_ps(val0, val0, sqr);
__m256 sign0 = _mm256_castsi256_ps(_mm256_slli_epi32(negbit0, 31));
val0 = _mm256_xor_ps(val0, sign0); // flip the sign where the negate bit is set
__m256 mask_x0 = _mm256_castsi256_ps(_mm256_cmpeq_epi32(top, _mm256_set1_epi32(0)));
qres[0] = _mm256_blendv_ps(val0, qres[0], mask_x0); // leave slot 0 empty where it holds the omitted component
// Component 1
__m256i mag1 = _mm256_and_si256(q, mask);
__m256i negbit1 = _mm256_and_si256(_mm256_srli_epi32(q, 9), _mm256_set1_epi32(1));
q = _mm256_srli_epi32(q, 10);
__m256 val1 = _mm256_mul_ps(_mm256_cvtepi32_ps(mag1), inv_msk);
val1 = _mm256_mul_ps(val1, half);
sqr = _mm256_fmadd_ps(val1, val1, sqr);
__m256 sign1 = _mm256_castsi256_ps(_mm256_slli_epi32(negbit1, 31));
val1 = _mm256_xor_ps(val1, sign1);
__m256 mask_y1 = _mm256_castsi256_ps(_mm256_cmpeq_epi32(top, _mm256_set1_epi32(1)));
qres[1] = _mm256_blendv_ps(val1, qres[1], mask_y1);
// Component 2
__m256i mag2 = _mm256_and_si256(q, mask);
__m256i negbit2 = _mm256_and_si256(_mm256_srli_epi32(q, 9), _mm256_set1_epi32(1));
q = _mm256_srli_epi32(q, 10);
__m256 val2 = _mm256_mul_ps(_mm256_cvtepi32_ps(mag2), inv_msk);
val2 = _mm256_mul_ps(val2, half);
sqr = _mm256_fmadd_ps(val2, val2, sqr);
__m256 sign2 = _mm256_castsi256_ps(_mm256_slli_epi32(negbit2, 31));
val2 = _mm256_xor_ps(val2, sign2);
__m256 mask_z2 = _mm256_castsi256_ps(_mm256_cmpeq_epi32(top, _mm256_set1_epi32(2)));
qres[2] = _mm256_blendv_ps(val2, qres[2], mask_z2);
// Reconstruct the omitted (largest) component: w' = sqrt(1 - x² - y² - z²)
__m256 diff = _mm256_max_ps(_mm256_sub_ps(one, sqr), zero); // clamp 1 - sqr to >= 0
__m256 root = _mm256_rcp_ps(_mm256_rsqrt_ps(diff)); // ≈ √(1 - sqr); rcp/rsqrt give ~12-bit precision
qres[0] = _mm256_blendv_ps(qres[0], root, _mm256_castsi256_ps(_mm256_cmpeq_epi32(top, _mm256_set1_epi32(0))));
qres[1] = _mm256_blendv_ps(qres[1], root, _mm256_castsi256_ps(_mm256_cmpeq_epi32(top, _mm256_set1_epi32(1))));
qres[2] = _mm256_blendv_ps(qres[2], root, _mm256_castsi256_ps(_mm256_cmpeq_epi32(top, _mm256_set1_epi32(2))));
qres[3] = _mm256_blendv_ps(qres[3], root, _mm256_castsi256_ps(_mm256_cmpeq_epi32(top, _mm256_set1_epi32(3))));
*qout_x = qres[0];
*qout_y = qres[1];
*qout_z = qres[2];
*qout_w = qres[3];
}
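/* Scalar reference decode, for illustration only. Assumes the packing implied
 * by qdec_avx above: bits 30-31 select the omitted (largest) component, and
 * each of the three stored components is 9 bits of magnitude plus a sign bit,
 * mapped to [-1/sqrt(2), 1/sqrt(2)]; the slot assignment below is the usual
 * smallest-three ordering. */
static inline void
qdec_scalar(float out[4], uint32_t q) {
int top = (int)(q >> 30); // index of the omitted component
float sqr = 0.0f;
out[0] = out[1] = out[2] = out[3] = 0.0f;
for (int i = 0, slot = 0; i < 3; ++i, q >>= 10) {
if (slot == top) slot++; // skip the omitted slot
float v = (float)(q & 511) * (0.707106781f / 511.0f);
if ((q >> 9) & 1) v = -v; // apply the sign bit
sqr += v * v;
out[slot++] = v;
}
out[top] = sqrtf(fmaxf(1.0f - sqr, 0.0f)); // reconstruct the largest component
}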
// Gather one 16-bit key value per lane: a 32-bit gather also pulls in the
// neighboring short, so sign-extend the low 16 bits before converting.
// Note: reads 2 bytes past the final short, so pad key buffers by one element.
static inline __m256
gather_i16_ps(const uint16_t* base, __m256i idx, __m256 scl) {
__m256i raw = _mm256_i32gather_epi32((const int*)base, idx, 2);
raw = _mm256_srai_epi32(_mm256_slli_epi32(raw, 16), 16);
return _mm256_mul_ps(_mm256_cvtepi32_ps(raw), scl);
}
static void
anim_decode_frames(struct anm_frame* decoded_data, int start_mdl, int mdl_cnt) {
for (int m = start_mdl; m < start_mdl + mdl_cnt && m < MAX_MDLS; ++m) {
int batch_idx = m - start_mdl;
const struct anm* anm = &anms[mdl_anms[m].anim];
if (anm->frame_cnt <= 0) continue; // no animation data loaded for this model
int frame = mdl_anms[m].frame;
int mask = (frame + 1 < anm->frame_cnt);
int frame_next = frame + mask;
int key_idx_pos[2] = {anm->blks.frame_to_key[ANM_TRK_POS][frame], anm->blks.frame_to_key[ANM_TRK_POS][frame_next]};
int key_idx_rot[2] = {anm->blks.frame_to_key[ANM_TRK_ROT][frame], anm->blks.frame_to_key[ANM_TRK_ROT][frame_next]};
int key_idx_scl[2] = {anm->blks.frame_to_key[ANM_TRK_SCL][frame], anm->blks.frame_to_key[ANM_TRK_SCL][frame_next]};
const __m256 pos_scl = _mm256_set1_ps(anm->scl[ANM_TRK_POS]);
const __m256 rot_scl = _mm256_set1_ps(anm->scl[ANM_TRK_ROT]);
const __m256 scl_scl = _mm256_set1_ps(anm->scl[ANM_TRK_SCL]);
for (int j = 0; j < MAX_JOINTS; j += 16) {
// Base indices for frame 0 and frame 1, split into lo (0-7) and hi (8-15)
int base_idx_pos_lo[2] = {key_idx_pos[0] * MAX_JOINTS + j, key_idx_pos[1] * MAX_JOINTS + j};
int base_idx_pos_hi[2] = {key_idx_pos[0] * MAX_JOINTS + j + 8, key_idx_pos[1] * MAX_JOINTS + j + 8};
int base_idx_rot_lo[2] = {key_idx_rot[0] * MAX_JOINTS + j, key_idx_rot[1] * MAX_JOINTS + j};
int base_idx_rot_hi[2] = {key_idx_rot[0] * MAX_JOINTS + j + 8, key_idx_rot[1] * MAX_JOINTS + j + 8};
int base_idx_scl_lo[2] = {key_idx_scl[0] * MAX_JOINTS + j, key_idx_scl[1] * MAX_JOINTS + j};
int base_idx_scl_hi[2] = {key_idx_scl[0] * MAX_JOINTS + j + 8, key_idx_scl[1] * MAX_JOINTS + j + 8};
// Prefetch key data for the next 16 joints (addresses are approximate; prefetching is only a hint)
_mm_prefetch((const char*)&anm->keys.pos[key_idx_pos[0] * 6 + j + 16], _MM_HINT_T0);
_mm_prefetch((const char*)&anm->keys.pos[key_idx_pos[1] * 6 + j + 16], _MM_HINT_T0);
_mm_prefetch((const char*)&anm->keys.rot[key_idx_rot[0] * 8 + j + 16], _MM_HINT_T0);
_mm_prefetch((const char*)&anm->keys.rot[key_idx_rot[1] * 8 + j + 16], _MM_HINT_T0);
_mm_prefetch((const char*)&anm->keys.scl[key_idx_scl[0] * 2 + j + 16], _MM_HINT_T0);
_mm_prefetch((const char*)&anm->keys.scl[key_idx_scl[1] * 2 + j + 16], _MM_HINT_T0);
// Position (16 joints, lo/hi)
__m256i k0_pos_lo = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_POS][base_idx_pos_lo[0]]);
__m256i k0_pos_hi = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_POS][base_idx_pos_hi[0]]);
__m256i k1_pos_lo = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_POS][base_idx_pos_lo[1]]);
__m256i k1_pos_hi = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_POS][base_idx_pos_hi[1]]);
__m256i k0_pos_base_lo = _mm256_mullo_epi32(k0_pos_lo, _mm256_set1_epi32(6));
__m256i k0_pos_base_hi = _mm256_mullo_epi32(k0_pos_hi, _mm256_set1_epi32(6));
__m256i k1_pos_base_lo = _mm256_mullo_epi32(k1_pos_lo, _mm256_set1_epi32(6));
__m256i k1_pos_base_hi = _mm256_mullo_epi32(k1_pos_hi, _mm256_set1_epi32(6));
__m256 pos0_x_lo = gather_i16_ps(anm->keys.pos, k0_pos_base_lo, pos_scl);
__m256 pos0_x_hi = gather_i16_ps(anm->keys.pos, k0_pos_base_hi, pos_scl);
__m256 pos0_y_lo = gather_i16_ps(anm->keys.pos, _mm256_add_epi32(k0_pos_base_lo, _mm256_set1_epi32(1)), pos_scl);
__m256 pos0_y_hi = gather_i16_ps(anm->keys.pos, _mm256_add_epi32(k0_pos_base_hi, _mm256_set1_epi32(1)), pos_scl);
__m256 pos0_z_lo = gather_i16_ps(anm->keys.pos, _mm256_add_epi32(k0_pos_base_lo, _mm256_set1_epi32(2)), pos_scl);
__m256 pos0_z_hi = gather_i16_ps(anm->keys.pos, _mm256_add_epi32(k0_pos_base_hi, _mm256_set1_epi32(2)), pos_scl);
__m256 tan0_x_lo = gather_i16_ps(anm->keys.pos, _mm256_add_epi32(k0_pos_base_lo, _mm256_set1_epi32(3)), pos_scl);
__m256 tan0_x_hi = gather_i16_ps(anm->keys.pos, _mm256_add_epi32(k0_pos_base_hi, _mm256_set1_epi32(3)), pos_scl);
__m256 tan0_y_lo = gather_i16_ps(anm->keys.pos, _mm256_add_epi32(k0_pos_base_lo, _mm256_set1_epi32(4)), pos_scl);
__m256 tan0_y_hi = gather_i16_ps(anm->keys.pos, _mm256_add_epi32(k0_pos_base_hi, _mm256_set1_epi32(4)), pos_scl);
__m256 tan0_z_lo = gather_i16_ps(anm->keys.pos, _mm256_add_epi32(k0_pos_base_lo, _mm256_set1_epi32(5)), pos_scl);
__m256 tan0_z_hi = gather_i16_ps(anm->keys.pos, _mm256_add_epi32(k0_pos_base_hi, _mm256_set1_epi32(5)), pos_scl);
__m256 pos1_x_lo = gather_i16_ps(anm->keys.pos, k1_pos_base_lo, pos_scl);
__m256 pos1_x_hi = gather_i16_ps(anm->keys.pos, k1_pos_base_hi, pos_scl);
__m256 pos1_y_lo = gather_i16_ps(anm->keys.pos, _mm256_add_epi32(k1_pos_base_lo, _mm256_set1_epi32(1)), pos_scl);
__m256 pos1_y_hi = gather_i16_ps(anm->keys.pos, _mm256_add_epi32(k1_pos_base_hi, _mm256_set1_epi32(1)), pos_scl);
__m256 pos1_z_lo = gather_i16_ps(anm->keys.pos, _mm256_add_epi32(k1_pos_base_lo, _mm256_set1_epi32(2)), pos_scl);
__m256 pos1_z_hi = gather_i16_ps(anm->keys.pos, _mm256_add_epi32(k1_pos_base_hi, _mm256_set1_epi32(2)), pos_scl);
__m256 tan1_x_lo = gather_i16_ps(anm->keys.pos, _mm256_add_epi32(k1_pos_base_lo, _mm256_set1_epi32(3)), pos_scl);
__m256 tan1_x_hi = gather_i16_ps(anm->keys.pos, _mm256_add_epi32(k1_pos_base_hi, _mm256_set1_epi32(3)), pos_scl);
__m256 tan1_y_lo = gather_i16_ps(anm->keys.pos, _mm256_add_epi32(k1_pos_base_lo, _mm256_set1_epi32(4)), pos_scl);
__m256 tan1_y_hi = gather_i16_ps(anm->keys.pos, _mm256_add_epi32(k1_pos_base_hi, _mm256_set1_epi32(4)), pos_scl);
__m256 tan1_z_lo = gather_i16_ps(anm->keys.pos, _mm256_add_epi32(k1_pos_base_lo, _mm256_set1_epi32(5)), pos_scl);
__m256 tan1_z_hi = gather_i16_ps(anm->keys.pos, _mm256_add_epi32(k1_pos_base_hi, _mm256_set1_epi32(5)), pos_scl);
// Scale (2 shorts)
__m256i k0_scl_lo = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_SCL][base_idx_scl_lo[0]]);
__m256i k0_scl_hi = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_SCL][base_idx_scl_hi[0]]);
__m256i k1_scl_lo = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_SCL][base_idx_scl_lo[1]]);
__m256i k1_scl_hi = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_SCL][base_idx_scl_hi[1]]);
__m256i k0_scl_base_lo = _mm256_mullo_epi32(k0_scl_lo, _mm256_set1_epi32(2));
__m256i k0_scl_base_hi = _mm256_mullo_epi32(k0_scl_hi, _mm256_set1_epi32(2));
__m256i k1_scl_base_lo = _mm256_mullo_epi32(k1_scl_lo, _mm256_set1_epi32(2));
__m256i k1_scl_base_hi = _mm256_mullo_epi32(k1_scl_hi, _mm256_set1_epi32(2));
__m256 scl0_lo = gather_i16_ps(anm->keys.scl, k0_scl_base_lo, scl_scl);
__m256 scl0_hi = gather_i16_ps(anm->keys.scl, k0_scl_base_hi, scl_scl);
__m256 tan0_s_lo = gather_i16_ps(anm->keys.scl, _mm256_add_epi32(k0_scl_base_lo, _mm256_set1_epi32(1)), scl_scl);
__m256 tan0_s_hi = gather_i16_ps(anm->keys.scl, _mm256_add_epi32(k0_scl_base_hi, _mm256_set1_epi32(1)), scl_scl);
__m256 scl1_lo = gather_i16_ps(anm->keys.scl, k1_scl_base_lo, scl_scl);
__m256 scl1_hi = gather_i16_ps(anm->keys.scl, k1_scl_base_hi, scl_scl);
__m256 tan1_s_lo = gather_i16_ps(anm->keys.scl, _mm256_add_epi32(k1_scl_base_lo, _mm256_set1_epi32(1)), scl_scl);
__m256 tan1_s_hi = gather_i16_ps(anm->keys.scl, _mm256_add_epi32(k1_scl_base_hi, _mm256_set1_epi32(1)), scl_scl);
// Rotation decoding (16 joints, lo/hi)
__m256i k0_rot_lo = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_ROT][base_idx_rot_lo[0]]);
__m256i k0_rot_hi = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_ROT][base_idx_rot_hi[0]]);
__m256i k1_rot_lo = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_ROT][base_idx_rot_lo[1]]);
__m256i k1_rot_hi = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_ROT][base_idx_rot_hi[1]]);
__m256i k0_rot_base_lo = _mm256_mullo_epi32(k0_rot_lo, _mm256_set1_epi32(8));
__m256i k0_rot_base_hi = _mm256_mullo_epi32(k0_rot_hi, _mm256_set1_epi32(8));
__m256i k1_rot_base_lo = _mm256_mullo_epi32(k1_rot_lo, _mm256_set1_epi32(8));
__m256i k1_rot_base_hi = _mm256_mullo_epi32(k1_rot_hi, _mm256_set1_epi32(8));
align32 uint32_t rot0_data_lo[8], rot0_data_hi[8], rot1_data_lo[8], rot1_data_hi[8];
_mm256_store_si256((__m256i*)rot0_data_lo, _mm256_i32gather_epi32((int*)anm->keys.rot, k0_rot_base_lo, 4));
_mm256_store_si256((__m256i*)rot0_data_hi, _mm256_i32gather_epi32((int*)anm->keys.rot, k0_rot_base_hi, 4));
_mm256_store_si256((__m256i*)rot1_data_lo, _mm256_i32gather_epi32((int*)anm->keys.rot, k1_rot_base_lo, 4));
_mm256_store_si256((__m256i*)rot1_data_hi, _mm256_i32gather_epi32((int*)anm->keys.rot, k1_rot_base_hi, 4));
__m256 rot0_x_lo, rot0_y_lo, rot0_z_lo, rot0_w_lo, rot1_x_lo, rot1_y_lo, rot1_z_lo, rot1_w_lo;
__m256 rot0_x_hi, rot0_y_hi, rot0_z_hi, rot0_w_hi, rot1_x_hi, rot1_y_hi, rot1_z_hi, rot1_w_hi;
qdec_avx(&rot0_x_lo, &rot0_y_lo, &rot0_z_lo, &rot0_w_lo, rot0_data_lo);
qdec_avx(&rot0_x_hi, &rot0_y_hi, &rot0_z_hi, &rot0_w_hi, rot0_data_hi);
qdec_avx(&rot1_x_lo, &rot1_y_lo, &rot1_z_lo, &rot1_w_lo, rot1_data_lo);
qdec_avx(&rot1_x_hi, &rot1_y_hi, &rot1_z_hi, &rot1_w_hi, rot1_data_hi);
_mm256_store_si256((__m256i*)rot0_data_lo, _mm256_i32gather_epi32((int*)anm->keys.rot, _mm256_add_epi32(k0_rot_base_lo, _mm256_set1_epi32(4)), 4));
_mm256_store_si256((__m256i*)rot0_data_hi, _mm256_i32gather_epi32((int*)anm->keys.rot, _mm256_add_epi32(k0_rot_base_hi, _mm256_set1_epi32(4)), 4));
_mm256_store_si256((__m256i*)rot1_data_lo, _mm256_i32gather_epi32((int*)anm->keys.rot, _mm256_add_epi32(k1_rot_base_lo, _mm256_set1_epi32(4)), 4));
_mm256_store_si256((__m256i*)rot1_data_hi, _mm256_i32gather_epi32((int*)anm->keys.rot, _mm256_add_epi32(k1_rot_base_hi, _mm256_set1_epi32(4)), 4));
__m256 tan0_rx_lo, tan0_ry_lo, tan0_rz_lo, tan0_rw_lo, tan1_rx_lo, tan1_ry_lo, tan1_rz_lo, tan1_rw_lo;
__m256 tan0_rx_hi, tan0_ry_hi, tan0_rz_hi, tan0_rw_hi, tan1_rx_hi, tan1_ry_hi, tan1_rz_hi, tan1_rw_hi;
qdec_avx(&tan0_rx_lo, &tan0_ry_lo, &tan0_rz_lo, &tan0_rw_lo, rot0_data_lo);
qdec_avx(&tan0_rx_hi, &tan0_ry_hi, &tan0_rz_hi, &tan0_rw_hi, rot0_data_hi);
qdec_avx(&tan1_rx_lo, &tan1_ry_lo, &tan1_rz_lo, &tan1_rw_lo, rot1_data_lo);
qdec_avx(&tan1_rx_hi, &tan1_ry_hi, &tan1_rz_hi, &tan1_rw_hi, rot1_data_hi);
rot0_x_lo = _mm256_mul_ps(rot0_x_lo, rot_scl);
rot0_y_lo = _mm256_mul_ps(rot0_y_lo, rot_scl);
rot0_z_lo = _mm256_mul_ps(rot0_z_lo, rot_scl);
rot0_x_hi = _mm256_mul_ps(rot0_x_hi, rot_scl);
rot0_y_hi = _mm256_mul_ps(rot0_y_hi, rot_scl);
rot0_z_hi = _mm256_mul_ps(rot0_z_hi, rot_scl);
tan0_rx_lo = _mm256_mul_ps(tan0_rx_lo, rot_scl);
tan0_ry_lo = _mm256_mul_ps(tan0_ry_lo, rot_scl);
tan0_rz_lo = _mm256_mul_ps(tan0_rz_lo, rot_scl);
tan0_rx_hi = _mm256_mul_ps(tan0_rx_hi, rot_scl);
tan0_ry_hi = _mm256_mul_ps(tan0_ry_hi, rot_scl);
tan0_rz_hi = _mm256_mul_ps(tan0_rz_hi, rot_scl);
rot1_x_lo = _mm256_mul_ps(rot1_x_lo, rot_scl);
rot1_y_lo = _mm256_mul_ps(rot1_y_lo, rot_scl);
rot1_z_lo = _mm256_mul_ps(rot1_z_lo, rot_scl);
rot1_x_hi = _mm256_mul_ps(rot1_x_hi, rot_scl);
rot1_y_hi = _mm256_mul_ps(rot1_y_hi, rot_scl);
rot1_z_hi = _mm256_mul_ps(rot1_z_hi, rot_scl);
tan1_rx_lo = _mm256_mul_ps(tan1_rx_lo, rot_scl);
tan1_ry_lo = _mm256_mul_ps(tan1_ry_lo, rot_scl);
tan1_rz_lo = _mm256_mul_ps(tan1_rz_lo, rot_scl);
tan1_rx_hi = _mm256_mul_ps(tan1_rx_hi, rot_scl);
tan1_ry_hi = _mm256_mul_ps(tan1_ry_hi, rot_scl);
tan1_rz_hi = _mm256_mul_ps(tan1_rz_hi, rot_scl);
// Store decoded positions and position tangents (16 joints, lo/hi)
_mm256_store_ps(&decoded_data[batch_idx].pos[0][0][j], pos0_x_lo);
_mm256_store_ps(&decoded_data[batch_idx].pos[0][0][j + 8], pos0_x_hi);
_mm256_store_ps(&decoded_data[batch_idx].pos[0][1][j], pos0_y_lo);
_mm256_store_ps(&decoded_data[batch_idx].pos[0][1][j + 8], pos0_y_hi);
_mm256_store_ps(&decoded_data[batch_idx].pos[0][2][j], pos0_z_lo);
_mm256_store_ps(&decoded_data[batch_idx].pos[0][2][j + 8], pos0_z_hi);
_mm256_store_ps(&decoded_data[batch_idx].pos[1][0][j], pos1_x_lo);
_mm256_store_ps(&decoded_data[batch_idx].pos[1][0][j + 8], pos1_x_hi);
_mm256_store_ps(&decoded_data[batch_idx].pos[1][1][j], pos1_y_lo);
_mm256_store_ps(&decoded_data[batch_idx].pos[1][1][j + 8], pos1_y_hi);
_mm256_store_ps(&decoded_data[batch_idx].pos[1][2][j], pos1_z_lo);
_mm256_store_ps(&decoded_data[batch_idx].pos[1][2][j + 8], pos1_z_hi);
_mm256_store_ps(&decoded_data[batch_idx].tan_pos[0][0][j], tan0_x_lo);
_mm256_store_ps(&decoded_data[batch_idx].tan_pos[0][0][j + 8], tan0_x_hi);
_mm256_store_ps(&decoded_data[batch_idx].tan_pos[0][1][j], tan0_y_lo);
_mm256_store_ps(&decoded_data[batch_idx].tan_pos[0][1][j + 8], tan0_y_hi);
_mm256_store_ps(&decoded_data[batch_idx].tan_pos[0][2][j], tan0_z_lo);
_mm256_store_ps(&decoded_data[batch_idx].tan_pos[0][2][j + 8], tan0_z_hi);
_mm256_store_ps(&decoded_data[batch_idx].tan_pos[1][0][j], tan1_x_lo);
_mm256_store_ps(&decoded_data[batch_idx].tan_pos[1][0][j + 8], tan1_x_hi);
_mm256_store_ps(&decoded_data[batch_idx].tan_pos[1][1][j], tan1_y_lo);
_mm256_store_ps(&decoded_data[batch_idx].tan_pos[1][1][j + 8], tan1_y_hi);
_mm256_store_ps(&decoded_data[batch_idx].tan_pos[1][2][j], tan1_z_lo);
_mm256_store_ps(&decoded_data[batch_idx].tan_pos[1][2][j + 8], tan1_z_hi);
_mm256_store_ps(&decoded_data[batch_idx].rot[0][0][j], rot0_x_lo);
_mm256_store_ps(&decoded_data[batch_idx].rot[0][0][j + 8], rot0_x_hi);
_mm256_store_ps(&decoded_data[batch_idx].rot[0][1][j], rot0_y_lo);
_mm256_store_ps(&decoded_data[batch_idx].rot[0][1][j + 8], rot0_y_hi);
_mm256_store_ps(&decoded_data[batch_idx].rot[0][2][j], rot0_z_lo);
_mm256_store_ps(&decoded_data[batch_idx].rot[0][2][j + 8], rot0_z_hi);
_mm256_store_ps(&decoded_data[batch_idx].rot[1][0][j], rot1_x_lo);
_mm256_store_ps(&decoded_data[batch_idx].rot[1][0][j + 8], rot1_x_hi);
_mm256_store_ps(&decoded_data[batch_idx].rot[1][1][j], rot1_y_lo);
_mm256_store_ps(&decoded_data[batch_idx].rot[1][1][j + 8], rot1_y_hi);
_mm256_store_ps(&decoded_data[batch_idx].rot[1][2][j], rot1_z_lo);
_mm256_store_ps(&decoded_data[batch_idx].rot[1][2][j + 8], rot1_z_hi);
_mm256_store_ps(&decoded_data[batch_idx].tan_rot[0][0][j], tan0_rx_lo);
_mm256_store_ps(&decoded_data[batch_idx].tan_rot[0][0][j + 8], tan0_rx_hi);
_mm256_store_ps(&decoded_data[batch_idx].tan_rot[0][1][j], tan0_ry_lo);
_mm256_store_ps(&decoded_data[batch_idx].tan_rot[0][1][j + 8], tan0_ry_hi);
_mm256_store_ps(&decoded_data[batch_idx].tan_rot[0][2][j], tan0_rz_lo);
_mm256_store_ps(&decoded_data[batch_idx].tan_rot[0][2][j + 8], tan0_rz_hi);
_mm256_store_ps(&decoded_data[batch_idx].tan_rot[1][0][j], tan1_rx_lo);
_mm256_store_ps(&decoded_data[batch_idx].tan_rot[1][0][j + 8], tan1_rx_hi);
_mm256_store_ps(&decoded_data[batch_idx].tan_rot[1][1][j], tan1_ry_lo);
_mm256_store_ps(&decoded_data[batch_idx].tan_rot[1][1][j + 8], tan1_ry_hi);
_mm256_store_ps(&decoded_data[batch_idx].tan_rot[1][2][j], tan1_rz_lo);
_mm256_store_ps(&decoded_data[batch_idx].tan_rot[1][2][j + 8], tan1_rz_hi);
_mm256_store_ps(&decoded_data[batch_idx].scl[0][j], scl0_lo);
_mm256_store_ps(&decoded_data[batch_idx].scl[0][j+8], scl0_hi);
_mm256_store_ps(&decoded_data[batch_idx].scl[1][j], scl1_lo);
_mm256_store_ps(&decoded_data[batch_idx].scl[1][j+8], scl1_hi);
_mm256_store_ps(&decoded_data[batch_idx].tan_scl[0][j], tan0_s_lo);
_mm256_store_ps(&decoded_data[batch_idx].tan_scl[0][j+8], tan0_s_hi);
_mm256_store_ps(&decoded_data[batch_idx].tan_scl[1][j], tan1_s_lo);
_mm256_store_ps(&decoded_data[batch_idx].tan_scl[1][j+8], tan1_s_hi);
}
}
}
static void
compute_hermite_coeffs(float* restrict t_scalar, float* restrict h_coeffs, int start_mdl, int mdl_cnt) {
assert(mdl_cnt <= BATCH_SIZE); // Ensure we don’t exceed batch size (4)
// Compute t values for each mdl in the batch
for (int m = start_mdl; m < start_mdl + mdl_cnt && m < MAX_MDLS; ++m) {
int idx = m - start_mdl;
const struct anm* anm = &anms[mdl_anms[m].anim];
int frame = mdl_anms[m].frame;
int frame_next = (frame + 1 < anm->frame_cnt ? frame + 1 : frame);
float denom = (float)(frame_next - frame);
t_scalar[idx] = (denom > 0.0f) ? (mdl_anms[m].exact_frame - (float)frame) / denom : 0.0f; // guard the last frame (denom == 0)
}
// Vectorized Hermite coefficient computation for 4 mdls
__m128 t = _mm_load_ps(t_scalar); // Load 4 t values
__m128 zero = _mm_setzero_ps();
__m128 one = _mm_set1_ps(1.0f);
__m128 t_clamped = _mm_min_ps(_mm_max_ps(t, zero), one); // Clamp t to [0, 1]
__m128 t2 = _mm_mul_ps(t_clamped, t_clamped); // t²
__m128 t3 = _mm_mul_ps(t2, t_clamped); // t³
__m128 two = _mm_set1_ps(2.0f);
__m128 neg_three = _mm_set1_ps(-3.0f);
// Compute Hermite basis functions
__m128 h00 = _mm_fmadd_ps(two, t3, _mm_fmadd_ps(neg_three, t2, one)); // 2t³ - 3t² + 1
__m128 h01 = _mm_fmadd_ps(_mm_set1_ps(-2.0f), t3, _mm_mul_ps(_mm_set1_ps(3.0f), t2)); // -2t³ + 3t²
__m128 h10 = _mm_fmadd_ps(t3, one, _mm_fmadd_ps(t2, _mm_set1_ps(-2.0f), t_clamped)); // t³ - 2t² + t
__m128 h11 = _mm_fmsub_ps(t3, one, t2); // t³ - t²
// Store coefficients grouped by type
_mm_store_ps(&h_coeffs[0], h00); // h00_m0-m3
_mm_store_ps(&h_coeffs[4], h01); // h01_m0-m3
_mm_store_ps(&h_coeffs[8], h10); // h10_m0-m3
_mm_store_ps(&h_coeffs[12], h11); // h11_m0-m3
}
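/* Scalar reference for the cubic Hermite blend used below, for illustration:
 * given endpoints p0, p1 and tangents m0, m1, evaluates the curve at t in
 * [0, 1] with the same basis functions computed above. */
static inline float
hermite_scalar(float p0, float p1, float m0, float m1, float t) {
float t2 = t * t;
float t3 = t2 * t;
return (2.0f * t3 - 3.0f * t2 + 1.0f) * p0 // h00
+ (-2.0f * t3 + 3.0f * t2) * p1 // h01
+ (t3 - 2.0f * t2 + t) * m0 // h10
+ (t3 - t2) * m1; // h11
}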
static void
anim_interpolate_frames(const struct anm_frame* decoded_data, float* out, int start_mdl, int mdl_cnt, const float* restrict h_coeffs) {
float* out_ptr = out;
// Process up to 4 mdls (BATCH_SIZE = 4)
for (int m = start_mdl; m < start_mdl + mdl_cnt && m < MAX_MDLS; ++m) {
int batch_idx = m - start_mdl;
// Load Hermite coefficients for this mdl (broadcast to all 8 lanes)
__m256 h00 = _mm256_set1_ps(h_coeffs[0 + batch_idx]); // h00 block
__m256 h01 = _mm256_set1_ps(h_coeffs[4 + batch_idx]); // h01 block
__m256 h10 = _mm256_set1_ps(h_coeffs[8 + batch_idx]); // h10 block
__m256 h11 = _mm256_set1_ps(h_coeffs[12 + batch_idx]); // h11 block
for (int j = 0; j < MAX_JOINTS; j += 16) {
// Load decoded data for 16 joints (lo: 0-7, hi: 8-15) for this mdl
__m256 pos0_x_lo = _mm256_load_ps(&decoded_data[batch_idx].pos[0][0][j]);
__m256 pos0_x_hi = _mm256_load_ps(&decoded_data[batch_idx].pos[0][0][j + 8]);
__m256 pos0_y_lo = _mm256_load_ps(&decoded_data[batch_idx].pos[0][1][j]);
__m256 pos0_y_hi = _mm256_load_ps(&decoded_data[batch_idx].pos[0][1][j + 8]);
__m256 pos0_z_lo = _mm256_load_ps(&decoded_data[batch_idx].pos[0][2][j]);
__m256 pos0_z_hi = _mm256_load_ps(&decoded_data[batch_idx].pos[0][2][j + 8]);
__m256 pos1_x_lo = _mm256_load_ps(&decoded_data[batch_idx].pos[1][0][j]);
__m256 pos1_x_hi = _mm256_load_ps(&decoded_data[batch_idx].pos[1][0][j + 8]);
__m256 pos1_y_lo = _mm256_load_ps(&decoded_data[batch_idx].pos[1][1][j]);
__m256 pos1_y_hi = _mm256_load_ps(&decoded_data[batch_idx].pos[1][1][j + 8]);
__m256 pos1_z_lo = _mm256_load_ps(&decoded_data[batch_idx].pos[1][2][j]);
__m256 pos1_z_hi = _mm256_load_ps(&decoded_data[batch_idx].pos[1][2][j + 8]);
__m256 tan0_x_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[0][0][j]);
__m256 tan0_x_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[0][0][j + 8]);
__m256 tan0_y_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[0][1][j]);
__m256 tan0_y_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[0][1][j + 8]);
__m256 tan0_z_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[0][2][j]);
__m256 tan0_z_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[0][2][j + 8]);
__m256 tan1_x_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[1][0][j]);
__m256 tan1_x_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[1][0][j + 8]);
__m256 tan1_y_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[1][1][j]);
__m256 tan1_y_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[1][1][j + 8]);
__m256 tan1_z_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[1][2][j]);
__m256 tan1_z_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[1][2][j + 8]);
__m256 rot0_x_lo = _mm256_load_ps(&decoded_data[batch_idx].rot[0][0][j]);
__m256 rot0_x_hi = _mm256_load_ps(&decoded_data[batch_idx].rot[0][0][j + 8]);
__m256 rot0_y_lo = _mm256_load_ps(&decoded_data[batch_idx].rot[0][1][j]);
__m256 rot0_y_hi = _mm256_load_ps(&decoded_data[batch_idx].rot[0][1][j + 8]);
__m256 rot0_z_lo = _mm256_load_ps(&decoded_data[batch_idx].rot[0][2][j]);
__m256 rot0_z_hi = _mm256_load_ps(&decoded_data[batch_idx].rot[0][2][j + 8]);
__m256 rot1_x_lo = _mm256_load_ps(&decoded_data[batch_idx].rot[1][0][j]);
__m256 rot1_x_hi = _mm256_load_ps(&decoded_data[batch_idx].rot[1][0][j + 8]);
__m256 rot1_y_lo = _mm256_load_ps(&decoded_data[batch_idx].rot[1][1][j]);
__m256 rot1_y_hi = _mm256_load_ps(&decoded_data[batch_idx].rot[1][1][j + 8]);
__m256 rot1_z_lo = _mm256_load_ps(&decoded_data[batch_idx].rot[1][2][j]);
__m256 rot1_z_hi = _mm256_load_ps(&decoded_data[batch_idx].rot[1][2][j + 8]);
__m256 tan0_rx_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[0][0][j]);
__m256 tan0_rx_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[0][0][j + 8]);
__m256 tan0_ry_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[0][1][j]);
__m256 tan0_ry_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[0][1][j + 8]);
__m256 tan0_rz_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[0][2][j]);
__m256 tan0_rz_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[0][2][j + 8]);
__m256 tan1_rx_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[1][0][j]);
__m256 tan1_rx_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[1][0][j + 8]);
__m256 tan1_ry_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[1][1][j]);
__m256 tan1_ry_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[1][1][j + 8]);
__m256 tan1_rz_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[1][2][j]);
__m256 tan1_rz_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[1][2][j + 8]);
__m256 scl0_lo = _mm256_load_ps(&decoded_data[batch_idx].scl[0][j]);
__m256 scl0_hi = _mm256_load_ps(&decoded_data[batch_idx].scl[0][j + 8]);
__m256 scl1_lo = _mm256_load_ps(&decoded_data[batch_idx].scl[1][j]);
__m256 scl1_hi = _mm256_load_ps(&decoded_data[batch_idx].scl[1][j + 8]);
__m256 tan0_s_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_scl[0][j]);
__m256 tan0_s_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_scl[0][j + 8]);
__m256 tan1_s_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_scl[1][j]);
__m256 tan1_s_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_scl[1][j + 8]);
// Hermite interpolation: result = h00*p0 + h01*p1 + h10*t0 + h11*t1
__m256 result_x_lo = _mm256_fmadd_ps(h00, pos0_x_lo, _mm256_fmadd_ps(h01, pos1_x_lo, _mm256_fmadd_ps(h10, tan0_x_lo, _mm256_mul_ps(h11, tan1_x_lo))));
__m256 result_x_hi = _mm256_fmadd_ps(h00, pos0_x_hi, _mm256_fmadd_ps(h01, pos1_x_hi, _mm256_fmadd_ps(h10, tan0_x_hi, _mm256_mul_ps(h11, tan1_x_hi))));
__m256 result_y_lo = _mm256_fmadd_ps(h00, pos0_y_lo, _mm256_fmadd_ps(h01, pos1_y_lo, _mm256_fmadd_ps(h10, tan0_y_lo, _mm256_mul_ps(h11, tan1_y_lo))));
__m256 result_y_hi = _mm256_fmadd_ps(h00, pos0_y_hi, _mm256_fmadd_ps(h01, pos1_y_hi, _mm256_fmadd_ps(h10, tan0_y_hi, _mm256_mul_ps(h11, tan1_y_hi))));
__m256 result_z_lo = _mm256_fmadd_ps(h00, pos0_z_lo, _mm256_fmadd_ps(h01, pos1_z_lo, _mm256_fmadd_ps(h10, tan0_z_lo, _mm256_mul_ps(h11, tan1_z_lo))));
__m256 result_z_hi = _mm256_fmadd_ps(h00, pos0_z_hi, _mm256_fmadd_ps(h01, pos1_z_hi, _mm256_fmadd_ps(h10, tan0_z_hi, _mm256_mul_ps(h11, tan1_z_hi))));
__m256 result_scl_lo = _mm256_fmadd_ps(h00, scl0_lo, _mm256_fmadd_ps(h01, scl1_lo, _mm256_fmadd_ps(h10, tan0_s_lo, _mm256_mul_ps(h11, tan1_s_lo))));
__m256 result_scl_hi = _mm256_fmadd_ps(h00, scl0_hi, _mm256_fmadd_ps(h01, scl1_hi, _mm256_fmadd_ps(h10, tan0_s_hi, _mm256_mul_ps(h11, tan1_s_hi))));
__m256 result_rot_x_lo = _mm256_fmadd_ps(h00, rot0_x_lo, _mm256_fmadd_ps(h01, rot1_x_lo, _mm256_fmadd_ps(h10, tan0_rx_lo, _mm256_mul_ps(h11, tan1_rx_lo))));
__m256 result_rot_x_hi = _mm256_fmadd_ps(h00, rot0_x_hi, _mm256_fmadd_ps(h01, rot1_x_hi, _mm256_fmadd_ps(h10, tan0_rx_hi, _mm256_mul_ps(h11, tan1_rx_hi))));
__m256 result_rot_y_lo = _mm256_fmadd_ps(h00, rot0_y_lo, _mm256_fmadd_ps(h01, rot1_y_lo, _mm256_fmadd_ps(h10, tan0_ry_lo, _mm256_mul_ps(h11, tan1_ry_lo))));
__m256 result_rot_y_hi = _mm256_fmadd_ps(h00, rot0_y_hi, _mm256_fmadd_ps(h01, rot1_y_hi, _mm256_fmadd_ps(h10, tan0_ry_hi, _mm256_mul_ps(h11, tan1_ry_hi))));
__m256 result_rot_z_lo = _mm256_fmadd_ps(h00, rot0_z_lo, _mm256_fmadd_ps(h01, rot1_z_lo, _mm256_fmadd_ps(h10, tan0_rz_lo, _mm256_mul_ps(h11, tan1_rz_lo))));
__m256 result_rot_z_hi = _mm256_fmadd_ps(h00, rot0_z_hi, _mm256_fmadd_ps(h01, rot1_z_hi, _mm256_fmadd_ps(h10, tan0_rz_hi, _mm256_mul_ps(h11, tan1_rz_hi))));
// Stream to output (layout: x, y, z, scl, rot_x, rot_y, rot_z per joint)
_mm256_stream_ps(out_ptr + j + 0 * 256, result_x_lo);
_mm256_stream_ps(out_ptr + j + 8 + 0 * 256, result_x_hi);
_mm256_stream_ps(out_ptr + j + 1 * 256, result_y_lo);
_mm256_stream_ps(out_ptr + j + 8 + 1 * 256, result_y_hi);
_mm256_stream_ps(out_ptr + j + 2 * 256, result_z_lo);
_mm256_stream_ps(out_ptr + j + 8 + 2 * 256, result_z_hi);
_mm256_stream_ps(out_ptr + j + 3 * 256, result_scl_lo);
_mm256_stream_ps(out_ptr + j + 8 + 3 * 256, result_scl_hi);
_mm256_stream_ps(out_ptr + j + 4 * 256, result_rot_x_lo);
_mm256_stream_ps(out_ptr + j + 8 + 4 * 256, result_rot_x_hi);
_mm256_stream_ps(out_ptr + j + 5 * 256, result_rot_y_lo);
_mm256_stream_ps(out_ptr + j + 8 + 5 * 256, result_rot_y_hi);
_mm256_stream_ps(out_ptr + j + 6 * 256, result_rot_z_lo);
_mm256_stream_ps(out_ptr + j + 8 + 6 * 256, result_rot_z_hi);
}
out_ptr += MAX_JOINTS * ANIM_TRK_ELM_CNT; // Advance to the next mdl's output
}
}
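/* Illustration only: reading one joint's interpolated value back out of the
 * buffer written above. Each model occupies MAX_JOINTS * ANIM_TRK_ELM_CNT
 * floats, laid out as ANIM_TRK_ELM_CNT planes of MAX_JOINTS values each
 * (x, y, z, scl, rot_x, rot_y, rot_z). */
static inline float
anim_out_get(const float* out, int mdl, int elm, int joint) {
return out[(size_t)mdl * MAX_JOINTS * ANIM_TRK_ELM_CNT + (size_t)elm * MAX_JOINTS + joint];
}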
static void
anim_update(float* out, int mdl_start, int mdl_cnt) {
align32 struct anm_frame decoded_data[BATCH_SIZE]; // 112 KB (4 * 28 KB); consider moving to the heap
int remainder = (mdl_cnt % BATCH_SIZE);
int batch_cnt = mdl_cnt / BATCH_SIZE + !!remainder;
for (int b = 0; b < batch_cnt; b++) {
align32 float t_scalar[BATCH_SIZE] = {0}; // 16 B
align32 float h_coeffs[BATCH_SIZE * 4] = {0}; // 64 B: h00, h01, h10, h11
int start = mdl_start + b * BATCH_SIZE;
int cnt = (start + BATCH_SIZE <= mdl_start + mdl_cnt) ? BATCH_SIZE : remainder;
anim_decode_frames(decoded_data, start, cnt);
compute_hermite_coeffs(t_scalar, h_coeffs, start, cnt);
anim_interpolate_frames(decoded_data, out + start * MAX_JOINTS * ANIM_TRK_ELM_CNT, start, cnt, h_coeffs);
}
_mm_sfence(); // Ensure all streaming stores complete
}
/* ---------------------------------------------------------------------------
* Thread Pool
* ---------------------------------------------------------------------------
*/
#define MAX_THREADS 128
#define QUEUE_SIZE 1024
// Task descriptor
struct thrd_pool_job {
float* out; // Output buffer for this chunk
int start; // Starting model index
int cnt; // Number of models to process
int valid; // Is this tsk active?
};
// Thread pool structure
struct thrd_pool {
pthread_t threads[MAX_THREADS];
struct thrd_pool_job queue[QUEUE_SIZE];
int thread_cnt;
int tsk_cnt; // Tasks currently queued
int tsks_submitted; // Tasks submitted since the last wait
int tsks_done; // Tasks completed since the last wait
int shutdown;
pthread_mutex_t mutex;
pthread_cond_t cond_not_empty;
pthread_cond_t cond_done;
};
static struct thrd_pool pool;
// Worker thread function
static void*
thrd_wrkr(void* arg) {
struct thrd_pool* pool = (struct thrd_pool*)arg;
while (1) {
struct thrd_pool_job tsk;
int got_tsk = 0;
pthread_mutex_lock(&pool->mutex);
while (!pool->shutdown && pool->tsk_cnt == 0) {
pthread_cond_wait(&pool->cond_not_empty, &pool->mutex);
}
if (pool->shutdown && pool->tsk_cnt == 0) {
pthread_mutex_unlock(&pool->mutex);
break;
}
if (pool->tsk_cnt > 0) {
tsk = pool->queue[0];
for (int i = 1; i < pool->tsk_cnt; i++) {
pool->queue[i - 1] = pool->queue[i];
}
pool->tsk_cnt--;
got_tsk = 1;
}
pthread_mutex_unlock(&pool->mutex);
if (got_tsk) {
anim_update(tsk.out, tsk.start, tsk.cnt);
pthread_mutex_lock(&pool->mutex);
pool->tsks_done++;
if (pool->tsks_done == pool->tsks_submitted) { // all submitted work finished
pthread_cond_broadcast(&pool->cond_done);
}
pthread_mutex_unlock(&pool->mutex);
}
}
return 0;
}
static void
thrd_pool_init(int num_threads) {
num_threads = (num_threads > MAX_THREADS) ? MAX_THREADS : num_threads;
pthread_mutex_init(&pool.mutex, 0);
pthread_cond_init(&pool.cond_not_empty, 0);
pthread_cond_init(&pool.cond_done, 0);
pool.thread_cnt = num_threads;
pool.tsk_cnt = 0;
pool.tsks_submitted = 0;
pool.tsks_done = 0;
pool.shutdown = 0;
for (int i = 0; i < num_threads; i++) {
pthread_create(&pool.threads[i], 0, thrd_wrkr, &pool);
}
}
static void
thrd_pool_submit(float* out, int start, int cnt) {
pthread_mutex_lock(&pool.mutex);
if (pool.tsk_cnt < QUEUE_SIZE) {
struct thrd_pool_job tsk = {out, start, cnt, 1};
pool.queue[pool.tsk_cnt++] = tsk;
pool.tsks_submitted++;
pthread_cond_signal(&pool.cond_not_empty);
} else {
fprintf(stderr, "task queue full!\n");
}
pthread_mutex_unlock(&pool.mutex);
static void
thrd_pool_wait(void) {
pthread_mutex_lock(&pool.mutex);
while (pool.tsks_done < pool.tsks_submitted) {
pthread_cond_wait(&pool.cond_done, &pool.mutex);
}
pool.tsks_done = 0;
pool.tsks_submitted = 0;
pthread_mutex_unlock(&pool.mutex);
}
static void
thrd_pool_shutdown(void) {
pthread_mutex_lock(&pool.mutex);
pool.shutdown = 1;
pthread_cond_broadcast(&pool.cond_not_empty);
pthread_mutex_unlock(&pool.mutex);
for (int i = 0; i < pool.thread_cnt; i++) {
pthread_join(pool.threads[i], 0);
}
pthread_mutex_destroy(&pool.mutex);
pthread_cond_destroy(&pool.cond_not_empty);
pthread_cond_destroy(&pool.cond_done);
}
// Multi-threaded anim_update wrapper with batch alignment
static void
job_anim_update(float* out, int total_models) {
int cores = pool.thread_cnt;
int base_models_per_thread = total_models / cores; // Integer division
int remainder = total_models % cores; // Leftover models
int processed_models = 0;
for (int i = 0; i < cores; i++) {
int start = processed_models;
int cnt = base_models_per_thread + (i < remainder ? 1 : 0); // Distribute remainder
if (cnt > 0) {
// anim_update offsets into out by the absolute model index, so pass the base pointer.
// Trim cnt to a multiple of BATCH_SIZE; the remainder is handled below.
int aligned_cnt = (cnt / BATCH_SIZE) * BATCH_SIZE;
if (aligned_cnt > 0) {
thrd_pool_submit(out, start, aligned_cnt);
}
processed_models += aligned_cnt;
}
}
// Handle remaining models (not a multiple of BATCH_SIZE)
int leftover = total_models - processed_models;
if (leftover > 0) {
thrd_pool_submit(out, processed_models, leftover);
}
thrd_pool_wait();
}
int
main(void) {
thrd_pool_init(8); // Example: 8 worker threads
int total_models = 77601; // Deliberately not a multiple of 4 or 8
// Streaming stores require 32-byte alignment, so allocate with aligned_alloc
float* out = aligned_alloc(32, (size_t)total_models * MAX_JOINTS * ANIM_TRK_ELM_CNT * sizeof(float));
if (!out) return 1;
job_anim_update(out, total_models);
free(out);
thrd_pool_shutdown();
return 0;
}