animation sampling: AVX2 keyframe decode and cubic Hermite interpolation, batched four models at a time and fanned out over a pthread pool
#include <immintrin.h> // AVX, includes SSE2-SSE4.2
#include <assert.h>
#include <math.h>
#include <stdint.h>
#include <pthread.h>
#include <stdlib.h>
#include <stdio.h>
// Constants
#define MAX_MDLS 256
#define MAX_ANIMS 256
#define BATCH_SIZE 4
#define MAX_JOINTS 256
#define ANIM_TRK_ELM_CNT 7 // x, y, z, scl, rot_x, rot_y, rot_z (7 floats per joint in the output)
#define align32 __attribute__((aligned(32)))
// Animation track types
enum anm_trks {
  ANM_TRK_POS, // 3 components (x, y, z)
  ANM_TRK_ROT, // 1 quaternion (32-bit encoded)
  ANM_TRK_SCL, // 1 component (uniform scale)
  ANM_TRK_CNT
};
struct anm {
  int frame_cnt; // Total number of frames
  float scl[ANM_TRK_CNT]; // Scaling factors for pos, rot, scl
  struct {
    uint16_t* pos; // 6 shorts: x, y, z, tx, ty, tz
    uint32_t* rot; // 8 bytes: x, y, z, w, tx, ty, tz, tw
    uint16_t* scl; // 2 shorts: s, ts
  } keys;
  struct {
    align32 int* key_off[ANM_TRK_CNT]; // Per-keyframe, per-joint offsets into keys
    int* frame_to_key[ANM_TRK_CNT]; // Maps frame index to keyframe index (sparse keyframes)
    int key_cnt[ANM_TRK_CNT]; // Number of keyframes per track
  } blks;
};
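// Data-layout notes (inferred from the decode path below; not spelled out in the original gist):
// blks.frame_to_key[trk][frame] maps a frame to its keyframe block, and
// blks.key_off[trk][key * 256 + joint] gives that joint's index into keys.* for the keyframe.
// keys.pos packs 6 quantized 16-bit values per entry (value xyz + tangent xyz) and keys.scl packs 2
// (value + tangent); keys.rot is indexed with a stride of 8 uint32s per entry, with the packed value
// quaternion read at +0 and the packed tangent at +4. All tracks are rescaled by scl[] during decode.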
struct mdl_anm {
  int anim;
  int frame;
  float exact_frame;
};
struct mdl {
  int x;
};
struct mdl mdls[MAX_MDLS];
struct anm anms[MAX_ANIMS];
struct mdl_anm mdl_anms[MAX_MDLS];
// SOA structure for decoded data, fixed 256 joints
struct anm_frame {
  float pos[2][3][256]; // [frame: 0/1][x,y,z][joint]
  float rot[2][3][256]; // [frame: 0/1][x,y,z][joint] (no w)
  float scl[2][256]; // [frame: 0/1][joint]
  float tan_pos[2][3][256]; // Tangents for position
  float tan_rot[2][3][256]; // Tangents for rotation (no tw)
  float tan_scl[2][256]; // Tangents for scale
};
static inline void
qdec_avx(__m256* qout_x, __m256* qout_y, __m256* qout_z, __m256* qout_w, const uint32_t* qin) {
  const __m256 half = _mm256_set1_ps(0.707106781f);
  const __m256 inv_msk = _mm256_set1_ps(1.0f / 511.0f);
  const __m256 one = _mm256_set1_ps(1.0f);
  const __m256 zero = _mm256_setzero_ps();
  __m256i q = _mm256_load_si256((__m256i*)qin);
  __m256i top = _mm256_srli_epi32(q, 30);
  __m256i mask = _mm256_set1_epi32(511);
  __m256 qres[4] = {zero, zero, zero, zero};
  __m256 sqr = zero;
  // Component 0
  __m256i mag0 = _mm256_and_si256(q, mask);
  __m256i negbit0 = _mm256_and_si256(_mm256_srli_epi32(q, 9), _mm256_set1_epi32(1));
  q = _mm256_srli_epi32(q, 10);
  __m256 val0 = _mm256_mul_ps(_mm256_cvtepi32_ps(mag0), inv_msk);
  val0 = _mm256_mul_ps(val0, half);
  sqr = _mm256_fmadd_ps(val0, val0, sqr);
  // Conditionally negate: flip the float sign bit where negbit is set
  __m256 sign0 = _mm256_castsi256_ps(_mm256_slli_epi32(negbit0, 31));
  val0 = _mm256_xor_ps(val0, sign0);
  // Keep val0 only where component 0 is not the dropped one (blendv keys off the mask MSB)
  __m256 mask_x0 = _mm256_castsi256_ps(_mm256_cmpeq_epi32(top, _mm256_set1_epi32(0)));
  qres[0] = _mm256_blendv_ps(val0, qres[0], mask_x0);
  // Component 1
  __m256i mag1 = _mm256_and_si256(q, mask);
  __m256i negbit1 = _mm256_and_si256(_mm256_srli_epi32(q, 9), _mm256_set1_epi32(1));
  q = _mm256_srli_epi32(q, 10);
  __m256 val1 = _mm256_mul_ps(_mm256_cvtepi32_ps(mag1), inv_msk);
  val1 = _mm256_mul_ps(val1, half);
  sqr = _mm256_fmadd_ps(val1, val1, sqr);
  __m256 sign1 = _mm256_castsi256_ps(_mm256_slli_epi32(negbit1, 31));
  val1 = _mm256_xor_ps(val1, sign1);
  __m256 mask_y1 = _mm256_castsi256_ps(_mm256_cmpeq_epi32(top, _mm256_set1_epi32(1)));
  qres[1] = _mm256_blendv_ps(val1, qres[1], mask_y1);
  // Component 2
  __m256i mag2 = _mm256_and_si256(q, mask);
  __m256i negbit2 = _mm256_and_si256(_mm256_srli_epi32(q, 9), _mm256_set1_epi32(1));
  q = _mm256_srli_epi32(q, 10);
  __m256 val2 = _mm256_mul_ps(_mm256_cvtepi32_ps(mag2), inv_msk);
  val2 = _mm256_mul_ps(val2, half);
  sqr = _mm256_fmadd_ps(val2, val2, sqr);
  __m256 sign2 = _mm256_castsi256_ps(_mm256_slli_epi32(negbit2, 31));
  val2 = _mm256_xor_ps(val2, sign2);
  __m256 mask_z2 = _mm256_castsi256_ps(_mm256_cmpeq_epi32(top, _mm256_set1_epi32(2)));
  qres[2] = _mm256_blendv_ps(val2, qres[2], mask_z2);
  // Missing component with rsqrt + rcp
  __m256 diff = _mm256_max_ps(_mm256_sub_ps(one, sqr), zero); // 1 - sqr
  __m256 root = _mm256_rcp_ps(_mm256_rsqrt_ps(diff)); // ≈ √(1 - sqr)
  qres[0] = _mm256_blendv_ps(qres[0], root, _mm256_castsi256_ps(_mm256_cmpeq_epi32(top, _mm256_set1_epi32(0))));
  qres[1] = _mm256_blendv_ps(qres[1], root, _mm256_castsi256_ps(_mm256_cmpeq_epi32(top, _mm256_set1_epi32(1))));
  qres[2] = _mm256_blendv_ps(qres[2], root, _mm256_castsi256_ps(_mm256_cmpeq_epi32(top, _mm256_set1_epi32(2))));
  qres[3] = _mm256_blendv_ps(qres[3], root, _mm256_castsi256_ps(_mm256_cmpeq_epi32(top, _mm256_set1_epi32(3))));
  *qout_x = qres[0];
  *qout_y = qres[1];
  *qout_z = qres[2];
  *qout_w = qres[3];
}
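
// For reference: a scalar sketch of the packed-quaternion layout assumed by qdec_avx above
// (top 2 bits = index of the dropped largest component, low 30 bits = three 10-bit fields of
// 9-bit magnitude + sign, each scaled to [-1/sqrt(2), 1/sqrt(2)]). This layout is inferred from
// the shifts and masks in the AVX path and is not stated in the original. Convention note: this
// sketch fills the three non-dropped components in storage order, whereas the AVX path maps
// field i straight to component i; use whichever matches the encoder.
static inline void
qdec_scalar(float out[4], uint32_t q) {
  int dropped = (int)(q >> 30);
  float sqr = 0.0f;
  for (int i = 0, k = 0; i < 4; ++i) {
    if (i == dropped) continue;
    uint32_t bits = (q >> (10 * k++)) & 0x3ffu;
    float v = (float)(bits & 0x1ffu) * (1.0f / 511.0f) * 0.707106781f;
    out[i] = (bits & 0x200u) ? -v : v;
    sqr += v * v;
  }
  out[dropped] = sqrtf(fmaxf(1.0f - sqr, 0.0f)); // Reconstruct the dropped component
}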
static void
anim_decode_frames(struct anm_frame* decoded_data, int start_mdl, int mdl_cnt) {
  for (int m = start_mdl; m < start_mdl + mdl_cnt && m < MAX_MDLS; ++m) {
    int batch_idx = m - start_mdl;
    const struct anm* anm = &anms[mdl_anms[m].anim];
    int frame = mdl_anms[m].frame;
    int mask = (frame + 1 < anm->frame_cnt);
    int frame_next = frame + mask;
    int key_idx_pos[2] = {anm->blks.frame_to_key[ANM_TRK_POS][frame], anm->blks.frame_to_key[ANM_TRK_POS][frame_next]};
    int key_idx_rot[2] = {anm->blks.frame_to_key[ANM_TRK_ROT][frame], anm->blks.frame_to_key[ANM_TRK_ROT][frame_next]};
    int key_idx_scl[2] = {anm->blks.frame_to_key[ANM_TRK_SCL][frame], anm->blks.frame_to_key[ANM_TRK_SCL][frame_next]};
    const __m256 pos_scl = _mm256_set1_ps(anm->scl[ANM_TRK_POS]);
    const __m256 rot_scl = _mm256_set1_ps(anm->scl[ANM_TRK_ROT]);
    const __m256 scl_scl = _mm256_set1_ps(anm->scl[ANM_TRK_SCL]);
    for (int j = 0; j < 256; j += 16) {
      // Base indices for frame 0 and frame 1, split into lo (0-7) and hi (8-15)
      int base_idx_pos_lo[2] = {key_idx_pos[0] * 256 + j, key_idx_pos[1] * 256 + j};
      int base_idx_pos_hi[2] = {key_idx_pos[0] * 256 + j + 8, key_idx_pos[1] * 256 + j + 8};
      int base_idx_rot_lo[2] = {key_idx_rot[0] * 256 + j, key_idx_rot[1] * 256 + j};
      int base_idx_rot_hi[2] = {key_idx_rot[0] * 256 + j + 8, key_idx_rot[1] * 256 + j + 8};
      int base_idx_scl_lo[2] = {key_idx_scl[0] * 256 + j, key_idx_scl[1] * 256 + j};
      int base_idx_scl_hi[2] = {key_idx_scl[0] * 256 + j + 8, key_idx_scl[1] * 256 + j + 8};
      // Prefetch next 16 joints
      _mm_prefetch((const char*)&anm->keys.pos[key_idx_pos[0] * 6 + j + 16], _MM_HINT_T0);
      _mm_prefetch((const char*)&anm->keys.pos[key_idx_pos[1] * 6 + j + 16], _MM_HINT_T0);
      _mm_prefetch((const char*)&anm->keys.rot[key_idx_rot[0] * 8 + j + 16], _MM_HINT_T0);
      _mm_prefetch((const char*)&anm->keys.rot[key_idx_rot[1] * 8 + j + 16], _MM_HINT_T0);
      _mm_prefetch((const char*)&anm->keys.scl[key_idx_scl[0] * 2 + j + 16], _MM_HINT_T0);
      _mm_prefetch((const char*)&anm->keys.scl[key_idx_scl[1] * 2 + j + 16], _MM_HINT_T0);
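      // NOTE: the position/scale gathers below use _mm256_i32gather_epi32 with a 2-byte scale over
      // 16-bit key streams, so each lane loads the target short plus its neighbour in the upper 16
      // bits; a robust decoder would mask/sign-extend the low half before _mm256_cvtepi32_ps.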
      // Position (16 joints, lo/hi)
      __m256i k0_pos_lo = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_POS][base_idx_pos_lo[0]]);
      __m256i k0_pos_hi = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_POS][base_idx_pos_hi[0]]);
      __m256i k1_pos_lo = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_POS][base_idx_pos_lo[1]]);
      __m256i k1_pos_hi = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_POS][base_idx_pos_hi[1]]);
      __m256i k0_pos_base_lo = _mm256_mullo_epi32(k0_pos_lo, _mm256_set1_epi32(6));
      __m256i k0_pos_base_hi = _mm256_mullo_epi32(k0_pos_hi, _mm256_set1_epi32(6));
      __m256i k1_pos_base_lo = _mm256_mullo_epi32(k1_pos_lo, _mm256_set1_epi32(6));
      __m256i k1_pos_base_hi = _mm256_mullo_epi32(k1_pos_hi, _mm256_set1_epi32(6));
      __m256 pos0_x_lo = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, k0_pos_base_lo, 2)), pos_scl);
      __m256 pos0_x_hi = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, k0_pos_base_hi, 2)), pos_scl);
      __m256 pos0_y_lo = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, _mm256_add_epi32(k0_pos_base_lo, _mm256_set1_epi32(1)), 2)), pos_scl);
      __m256 pos0_y_hi = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, _mm256_add_epi32(k0_pos_base_hi, _mm256_set1_epi32(1)), 2)), pos_scl);
      __m256 pos0_z_lo = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, _mm256_add_epi32(k0_pos_base_lo, _mm256_set1_epi32(2)), 2)), pos_scl);
      __m256 pos0_z_hi = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, _mm256_add_epi32(k0_pos_base_hi, _mm256_set1_epi32(2)), 2)), pos_scl);
      __m256 tan0_x_lo = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, _mm256_add_epi32(k0_pos_base_lo, _mm256_set1_epi32(3)), 2)), pos_scl);
      __m256 tan0_x_hi = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, _mm256_add_epi32(k0_pos_base_hi, _mm256_set1_epi32(3)), 2)), pos_scl);
      __m256 tan0_y_lo = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, _mm256_add_epi32(k0_pos_base_lo, _mm256_set1_epi32(4)), 2)), pos_scl);
      __m256 tan0_y_hi = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, _mm256_add_epi32(k0_pos_base_hi, _mm256_set1_epi32(4)), 2)), pos_scl);
      __m256 tan0_z_lo = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, _mm256_add_epi32(k0_pos_base_lo, _mm256_set1_epi32(5)), 2)), pos_scl);
      __m256 tan0_z_hi = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, _mm256_add_epi32(k0_pos_base_hi, _mm256_set1_epi32(5)), 2)), pos_scl);
      __m256 pos1_x_lo = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, k1_pos_base_lo, 2)), pos_scl);
      __m256 pos1_x_hi = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, k1_pos_base_hi, 2)), pos_scl);
      __m256 pos1_y_lo = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, _mm256_add_epi32(k1_pos_base_lo, _mm256_set1_epi32(1)), 2)), pos_scl);
      __m256 pos1_y_hi = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, _mm256_add_epi32(k1_pos_base_hi, _mm256_set1_epi32(1)), 2)), pos_scl);
      __m256 pos1_z_lo = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, _mm256_add_epi32(k1_pos_base_lo, _mm256_set1_epi32(2)), 2)), pos_scl);
      __m256 pos1_z_hi = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, _mm256_add_epi32(k1_pos_base_hi, _mm256_set1_epi32(2)), 2)), pos_scl);
      __m256 tan1_x_lo = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, _mm256_add_epi32(k1_pos_base_lo, _mm256_set1_epi32(3)), 2)), pos_scl);
      __m256 tan1_x_hi = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, _mm256_add_epi32(k1_pos_base_hi, _mm256_set1_epi32(3)), 2)), pos_scl);
      __m256 tan1_y_lo = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, _mm256_add_epi32(k1_pos_base_lo, _mm256_set1_epi32(4)), 2)), pos_scl);
      __m256 tan1_y_hi = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, _mm256_add_epi32(k1_pos_base_hi, _mm256_set1_epi32(4)), 2)), pos_scl);
      __m256 tan1_z_lo = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, _mm256_add_epi32(k1_pos_base_lo, _mm256_set1_epi32(5)), 2)), pos_scl);
      __m256 tan1_z_hi = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.pos, _mm256_add_epi32(k1_pos_base_hi, _mm256_set1_epi32(5)), 2)), pos_scl);
      // Scale (2 shorts)
      __m256i k0_scl_lo = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_SCL][base_idx_scl_lo[0]]);
      __m256i k0_scl_hi = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_SCL][base_idx_scl_hi[0]]);
      __m256i k1_scl_lo = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_SCL][base_idx_scl_lo[1]]);
      __m256i k1_scl_hi = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_SCL][base_idx_scl_hi[1]]);
      __m256i k0_scl_base_lo = _mm256_mullo_epi32(k0_scl_lo, _mm256_set1_epi32(2));
      __m256i k0_scl_base_hi = _mm256_mullo_epi32(k0_scl_hi, _mm256_set1_epi32(2));
      __m256i k1_scl_base_lo = _mm256_mullo_epi32(k1_scl_lo, _mm256_set1_epi32(2));
      __m256i k1_scl_base_hi = _mm256_mullo_epi32(k1_scl_hi, _mm256_set1_epi32(2));
      __m256 scl0_lo = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.scl, k0_scl_base_lo, 2)), scl_scl);
      __m256 scl0_hi = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.scl, k0_scl_base_hi, 2)), scl_scl);
      __m256 tan0_s_lo = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.scl, _mm256_add_epi32(k0_scl_base_lo, _mm256_set1_epi32(1)), 2)), scl_scl);
      __m256 tan0_s_hi = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.scl, _mm256_add_epi32(k0_scl_base_hi, _mm256_set1_epi32(1)), 2)), scl_scl);
      __m256 scl1_lo = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.scl, k1_scl_base_lo, 2)), scl_scl);
      __m256 scl1_hi = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.scl, k1_scl_base_hi, 2)), scl_scl);
      __m256 tan1_s_lo = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.scl, _mm256_add_epi32(k1_scl_base_lo, _mm256_set1_epi32(1)), 2)), scl_scl);
      __m256 tan1_s_hi = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_i32gather_epi32((int*)anm->keys.scl, _mm256_add_epi32(k1_scl_base_hi, _mm256_set1_epi32(1)), 2)), scl_scl);
      // Rotation decoding (16 joints, lo/hi)
      __m256i k0_rot_lo = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_ROT][base_idx_rot_lo[0]]);
      __m256i k0_rot_hi = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_ROT][base_idx_rot_hi[0]]);
      __m256i k1_rot_lo = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_ROT][base_idx_rot_lo[1]]);
      __m256i k1_rot_hi = _mm256_load_si256((__m256i*)&anm->blks.key_off[ANM_TRK_ROT][base_idx_rot_hi[1]]);
      __m256i k0_rot_base_lo = _mm256_mullo_epi32(k0_rot_lo, _mm256_set1_epi32(8));
      __m256i k0_rot_base_hi = _mm256_mullo_epi32(k0_rot_hi, _mm256_set1_epi32(8));
      __m256i k1_rot_base_lo = _mm256_mullo_epi32(k1_rot_lo, _mm256_set1_epi32(8));
      __m256i k1_rot_base_hi = _mm256_mullo_epi32(k1_rot_hi, _mm256_set1_epi32(8));
      align32 uint32_t rot0_data_lo[8], rot0_data_hi[8], rot1_data_lo[8], rot1_data_hi[8];
      _mm256_store_si256((__m256i*)rot0_data_lo, _mm256_i32gather_epi32((int*)anm->keys.rot, k0_rot_base_lo, 4));
      _mm256_store_si256((__m256i*)rot0_data_hi, _mm256_i32gather_epi32((int*)anm->keys.rot, k0_rot_base_hi, 4));
      _mm256_store_si256((__m256i*)rot1_data_lo, _mm256_i32gather_epi32((int*)anm->keys.rot, k1_rot_base_lo, 4));
      _mm256_store_si256((__m256i*)rot1_data_hi, _mm256_i32gather_epi32((int*)anm->keys.rot, k1_rot_base_hi, 4));
      __m256 rot0_x_lo, rot0_y_lo, rot0_z_lo, rot0_w_lo, rot1_x_lo, rot1_y_lo, rot1_z_lo, rot1_w_lo;
      __m256 rot0_x_hi, rot0_y_hi, rot0_z_hi, rot0_w_hi, rot1_x_hi, rot1_y_hi, rot1_z_hi, rot1_w_hi;
      qdec_avx(&rot0_x_lo, &rot0_y_lo, &rot0_z_lo, &rot0_w_lo, rot0_data_lo);
      qdec_avx(&rot0_x_hi, &rot0_y_hi, &rot0_z_hi, &rot0_w_hi, rot0_data_hi);
      qdec_avx(&rot1_x_lo, &rot1_y_lo, &rot1_z_lo, &rot1_w_lo, rot1_data_lo);
      qdec_avx(&rot1_x_hi, &rot1_y_hi, &rot1_z_hi, &rot1_w_hi, rot1_data_hi);
      _mm256_store_si256((__m256i*)rot0_data_lo, _mm256_i32gather_epi32((int*)anm->keys.rot, _mm256_add_epi32(k0_rot_base_lo, _mm256_set1_epi32(4)), 4));
      _mm256_store_si256((__m256i*)rot0_data_hi, _mm256_i32gather_epi32((int*)anm->keys.rot, _mm256_add_epi32(k0_rot_base_hi, _mm256_set1_epi32(4)), 4));
      _mm256_store_si256((__m256i*)rot1_data_lo, _mm256_i32gather_epi32((int*)anm->keys.rot, _mm256_add_epi32(k1_rot_base_lo, _mm256_set1_epi32(4)), 4));
      _mm256_store_si256((__m256i*)rot1_data_hi, _mm256_i32gather_epi32((int*)anm->keys.rot, _mm256_add_epi32(k1_rot_base_hi, _mm256_set1_epi32(4)), 4));
      __m256 tan0_rx_lo, tan0_ry_lo, tan0_rz_lo, tan0_rw_lo, tan1_rx_lo, tan1_ry_lo, tan1_rz_lo, tan1_rw_lo;
      __m256 tan0_rx_hi, tan0_ry_hi, tan0_rz_hi, tan0_rw_hi, tan1_rx_hi, tan1_ry_hi, tan1_rz_hi, tan1_rw_hi;
      qdec_avx(&tan0_rx_lo, &tan0_ry_lo, &tan0_rz_lo, &tan0_rw_lo, rot0_data_lo);
      qdec_avx(&tan0_rx_hi, &tan0_ry_hi, &tan0_rz_hi, &tan0_rw_hi, rot0_data_hi);
      qdec_avx(&tan1_rx_lo, &tan1_ry_lo, &tan1_rz_lo, &tan1_rw_lo, rot1_data_lo);
      qdec_avx(&tan1_rx_hi, &tan1_ry_hi, &tan1_rz_hi, &tan1_rw_hi, rot1_data_hi);
      rot0_x_lo = _mm256_mul_ps(rot0_x_lo, rot_scl);
      rot0_y_lo = _mm256_mul_ps(rot0_y_lo, rot_scl);
      rot0_z_lo = _mm256_mul_ps(rot0_z_lo, rot_scl);
      rot0_x_hi = _mm256_mul_ps(rot0_x_hi, rot_scl);
      rot0_y_hi = _mm256_mul_ps(rot0_y_hi, rot_scl);
      rot0_z_hi = _mm256_mul_ps(rot0_z_hi, rot_scl);
      tan0_rx_lo = _mm256_mul_ps(tan0_rx_lo, rot_scl);
      tan0_ry_lo = _mm256_mul_ps(tan0_ry_lo, rot_scl);
      tan0_rz_lo = _mm256_mul_ps(tan0_rz_lo, rot_scl);
      tan0_rx_hi = _mm256_mul_ps(tan0_rx_hi, rot_scl);
      tan0_ry_hi = _mm256_mul_ps(tan0_ry_hi, rot_scl);
      tan0_rz_hi = _mm256_mul_ps(tan0_rz_hi, rot_scl);
      rot1_x_lo = _mm256_mul_ps(rot1_x_lo, rot_scl);
      rot1_y_lo = _mm256_mul_ps(rot1_y_lo, rot_scl);
      rot1_z_lo = _mm256_mul_ps(rot1_z_lo, rot_scl);
      rot1_x_hi = _mm256_mul_ps(rot1_x_hi, rot_scl);
      rot1_y_hi = _mm256_mul_ps(rot1_y_hi, rot_scl);
      rot1_z_hi = _mm256_mul_ps(rot1_z_hi, rot_scl);
      tan1_rx_lo = _mm256_mul_ps(tan1_rx_lo, rot_scl);
      tan1_ry_lo = _mm256_mul_ps(tan1_ry_lo, rot_scl);
      tan1_rz_lo = _mm256_mul_ps(tan1_rz_lo, rot_scl);
      tan1_rx_hi = _mm256_mul_ps(tan1_rx_hi, rot_scl);
      tan1_ry_hi = _mm256_mul_ps(tan1_ry_hi, rot_scl);
      tan1_rz_hi = _mm256_mul_ps(tan1_rz_hi, rot_scl);
      // Store decoded values and tangents (16 joints, lo/hi)
      _mm256_store_ps(&decoded_data[batch_idx].pos[0][0][j], pos0_x_lo);
      _mm256_store_ps(&decoded_data[batch_idx].pos[0][0][j + 8], pos0_x_hi);
      _mm256_store_ps(&decoded_data[batch_idx].pos[0][1][j], pos0_y_lo);
      _mm256_store_ps(&decoded_data[batch_idx].pos[0][1][j + 8], pos0_y_hi);
      _mm256_store_ps(&decoded_data[batch_idx].pos[0][2][j], pos0_z_lo);
      _mm256_store_ps(&decoded_data[batch_idx].pos[0][2][j + 8], pos0_z_hi);
      _mm256_store_ps(&decoded_data[batch_idx].pos[1][0][j], pos1_x_lo);
      _mm256_store_ps(&decoded_data[batch_idx].pos[1][0][j + 8], pos1_x_hi);
      _mm256_store_ps(&decoded_data[batch_idx].pos[1][1][j], pos1_y_lo);
      _mm256_store_ps(&decoded_data[batch_idx].pos[1][1][j + 8], pos1_y_hi);
      _mm256_store_ps(&decoded_data[batch_idx].pos[1][2][j], pos1_z_lo);
      _mm256_store_ps(&decoded_data[batch_idx].pos[1][2][j + 8], pos1_z_hi);
      _mm256_store_ps(&decoded_data[batch_idx].tan_pos[0][0][j], tan0_x_lo);
      _mm256_store_ps(&decoded_data[batch_idx].tan_pos[0][0][j + 8], tan0_x_hi);
      _mm256_store_ps(&decoded_data[batch_idx].tan_pos[0][1][j], tan0_y_lo);
      _mm256_store_ps(&decoded_data[batch_idx].tan_pos[0][1][j + 8], tan0_y_hi);
      _mm256_store_ps(&decoded_data[batch_idx].tan_pos[0][2][j], tan0_z_lo);
      _mm256_store_ps(&decoded_data[batch_idx].tan_pos[0][2][j + 8], tan0_z_hi);
      _mm256_store_ps(&decoded_data[batch_idx].tan_pos[1][0][j], tan1_x_lo);
      _mm256_store_ps(&decoded_data[batch_idx].tan_pos[1][0][j + 8], tan1_x_hi);
      _mm256_store_ps(&decoded_data[batch_idx].tan_pos[1][1][j], tan1_y_lo);
      _mm256_store_ps(&decoded_data[batch_idx].tan_pos[1][1][j + 8], tan1_y_hi);
      _mm256_store_ps(&decoded_data[batch_idx].tan_pos[1][2][j], tan1_z_lo);
      _mm256_store_ps(&decoded_data[batch_idx].tan_pos[1][2][j + 8], tan1_z_hi);
      _mm256_store_ps(&decoded_data[batch_idx].rot[0][0][j], rot0_x_lo);
      _mm256_store_ps(&decoded_data[batch_idx].rot[0][0][j + 8], rot0_x_hi);
      _mm256_store_ps(&decoded_data[batch_idx].rot[0][1][j], rot0_y_lo);
      _mm256_store_ps(&decoded_data[batch_idx].rot[0][1][j + 8], rot0_y_hi);
      _mm256_store_ps(&decoded_data[batch_idx].rot[0][2][j], rot0_z_lo);
      _mm256_store_ps(&decoded_data[batch_idx].rot[0][2][j + 8], rot0_z_hi);
      _mm256_store_ps(&decoded_data[batch_idx].rot[1][0][j], rot1_x_lo);
      _mm256_store_ps(&decoded_data[batch_idx].rot[1][0][j + 8], rot1_x_hi);
      _mm256_store_ps(&decoded_data[batch_idx].rot[1][1][j], rot1_y_lo);
      _mm256_store_ps(&decoded_data[batch_idx].rot[1][1][j + 8], rot1_y_hi);
      _mm256_store_ps(&decoded_data[batch_idx].rot[1][2][j], rot1_z_lo);
      _mm256_store_ps(&decoded_data[batch_idx].rot[1][2][j + 8], rot1_z_hi);
      _mm256_store_ps(&decoded_data[batch_idx].tan_rot[0][0][j], tan0_rx_lo);
      _mm256_store_ps(&decoded_data[batch_idx].tan_rot[0][0][j + 8], tan0_rx_hi);
      _mm256_store_ps(&decoded_data[batch_idx].tan_rot[0][1][j], tan0_ry_lo);
      _mm256_store_ps(&decoded_data[batch_idx].tan_rot[0][1][j + 8], tan0_ry_hi);
      _mm256_store_ps(&decoded_data[batch_idx].tan_rot[0][2][j], tan0_rz_lo);
      _mm256_store_ps(&decoded_data[batch_idx].tan_rot[0][2][j + 8], tan0_rz_hi);
      _mm256_store_ps(&decoded_data[batch_idx].tan_rot[1][0][j], tan1_rx_lo);
      _mm256_store_ps(&decoded_data[batch_idx].tan_rot[1][0][j + 8], tan1_rx_hi);
      _mm256_store_ps(&decoded_data[batch_idx].tan_rot[1][1][j], tan1_ry_lo);
      _mm256_store_ps(&decoded_data[batch_idx].tan_rot[1][1][j + 8], tan1_ry_hi);
      _mm256_store_ps(&decoded_data[batch_idx].tan_rot[1][2][j], tan1_rz_lo);
      _mm256_store_ps(&decoded_data[batch_idx].tan_rot[1][2][j + 8], tan1_rz_hi);
      _mm256_store_ps(&decoded_data[batch_idx].scl[0][j], scl0_lo);
      _mm256_store_ps(&decoded_data[batch_idx].scl[0][j + 8], scl0_hi);
      _mm256_store_ps(&decoded_data[batch_idx].scl[1][j], scl1_lo);
      _mm256_store_ps(&decoded_data[batch_idx].scl[1][j + 8], scl1_hi);
      _mm256_store_ps(&decoded_data[batch_idx].tan_scl[0][j], tan0_s_lo);
      _mm256_store_ps(&decoded_data[batch_idx].tan_scl[0][j + 8], tan0_s_hi);
      _mm256_store_ps(&decoded_data[batch_idx].tan_scl[1][j], tan1_s_lo);
      _mm256_store_ps(&decoded_data[batch_idx].tan_scl[1][j + 8], tan1_s_hi);
    }
  }
}
static void
compute_hermite_coeffs(float* restrict t_scalar, float* restrict h_coeffs, int start_mdl, int mdl_cnt) {
  assert(mdl_cnt <= BATCH_SIZE); // Ensure we don't exceed batch size (4)
  // Compute t values for each mdl in the batch
  for (int m = start_mdl; m < start_mdl + mdl_cnt && m < MAX_MDLS; ++m) {
    int idx = m - start_mdl;
    const struct anm* anm = &anms[mdl_anms[m].anim];
    int frame = mdl_anms[m].frame;
    int frame_next = (frame + 1 < anm->frame_cnt ? frame + 1 : frame);
    // Guard against frame_next == frame on the last frame (would divide by zero)
    float denom = (float)frame_next - (float)frame;
    t_scalar[idx] = (denom > 0.0f) ? (mdl_anms[m].exact_frame - (float)frame) / denom : 0.0f;
  }
  // Vectorized Hermite coefficient computation for 4 mdls
  __m128 t = _mm_load_ps(t_scalar); // Load 4 t values
  __m128 zero = _mm_setzero_ps();
  __m128 one = _mm_set1_ps(1.0f);
  __m128 t_clamped = _mm_min_ps(_mm_max_ps(t, zero), one); // Clamp t to [0, 1]
  __m128 t2 = _mm_mul_ps(t_clamped, t_clamped); // t²
  __m128 t3 = _mm_fmadd_ps(t2, t_clamped, zero); // t³ = t² * t + 0
  __m128 two = _mm_set1_ps(2.0f);
  __m128 neg_three = _mm_set1_ps(-3.0f);
  // Compute Hermite basis functions
  __m128 h00 = _mm_fmadd_ps(two, t3, _mm_fmadd_ps(neg_three, t2, one)); // 2t³ - 3t² + 1
  __m128 h01 = _mm_fmadd_ps(_mm_set1_ps(-2.0f), t3, _mm_mul_ps(_mm_set1_ps(3.0f), t2)); // -2t³ + 3t²
  __m128 h10 = _mm_fmadd_ps(t3, one, _mm_fmadd_ps(t2, _mm_set1_ps(-2.0f), t_clamped)); // t³ - 2t² + t
  __m128 h11 = _mm_fmsub_ps(t3, one, t2); // t³ - t²
  // Store coefficients grouped by type
  _mm_store_ps(&h_coeffs[0], h00); // h00_m0-m3
  _mm_store_ps(&h_coeffs[4], h01); // h01_m0-m3
  _mm_store_ps(&h_coeffs[8], h10); // h10_m0-m3
  _mm_store_ps(&h_coeffs[12], h11); // h11_m0-m3
}
static void
anim_interpolate_frames(const struct anm_frame* decoded_data, float* out, int start_mdl, int mdl_cnt, const float* restrict h_coeffs) {
  float* out_ptr = out;
  // Process up to 4 mdls (BATCH_SIZE = 4)
  for (int m = start_mdl; m < start_mdl + mdl_cnt && m < MAX_MDLS; ++m) {
    int batch_idx = m - start_mdl;
    // Load Hermite coefficients for this mdl (broadcast to all 8 lanes)
    __m256 h00 = _mm256_set1_ps(h_coeffs[0 + batch_idx]); // h00 block
    __m256 h01 = _mm256_set1_ps(h_coeffs[4 + batch_idx]); // h01 block
    __m256 h10 = _mm256_set1_ps(h_coeffs[8 + batch_idx]); // h10 block
    __m256 h11 = _mm256_set1_ps(h_coeffs[12 + batch_idx]); // h11 block
    for (int j = 0; j < 256; j += 16) {
      // Load decoded data for 16 joints (lo: 0-7, hi: 8-15) for this mdl
      __m256 pos0_x_lo = _mm256_load_ps(&decoded_data[batch_idx].pos[0][0][j]);
      __m256 pos0_x_hi = _mm256_load_ps(&decoded_data[batch_idx].pos[0][0][j + 8]);
      __m256 pos0_y_lo = _mm256_load_ps(&decoded_data[batch_idx].pos[0][1][j]);
      __m256 pos0_y_hi = _mm256_load_ps(&decoded_data[batch_idx].pos[0][1][j + 8]);
      __m256 pos0_z_lo = _mm256_load_ps(&decoded_data[batch_idx].pos[0][2][j]);
      __m256 pos0_z_hi = _mm256_load_ps(&decoded_data[batch_idx].pos[0][2][j + 8]);
      __m256 pos1_x_lo = _mm256_load_ps(&decoded_data[batch_idx].pos[1][0][j]);
      __m256 pos1_x_hi = _mm256_load_ps(&decoded_data[batch_idx].pos[1][0][j + 8]);
      __m256 pos1_y_lo = _mm256_load_ps(&decoded_data[batch_idx].pos[1][1][j]);
      __m256 pos1_y_hi = _mm256_load_ps(&decoded_data[batch_idx].pos[1][1][j + 8]);
      __m256 pos1_z_lo = _mm256_load_ps(&decoded_data[batch_idx].pos[1][2][j]);
      __m256 pos1_z_hi = _mm256_load_ps(&decoded_data[batch_idx].pos[1][2][j + 8]);
      __m256 tan0_x_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[0][0][j]);
      __m256 tan0_x_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[0][0][j + 8]);
      __m256 tan0_y_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[0][1][j]);
      __m256 tan0_y_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[0][1][j + 8]);
      __m256 tan0_z_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[0][2][j]);
      __m256 tan0_z_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[0][2][j + 8]);
      __m256 tan1_x_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[1][0][j]);
      __m256 tan1_x_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[1][0][j + 8]);
      __m256 tan1_y_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[1][1][j]);
      __m256 tan1_y_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[1][1][j + 8]);
      __m256 tan1_z_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[1][2][j]);
      __m256 tan1_z_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_pos[1][2][j + 8]);
      __m256 rot0_x_lo = _mm256_load_ps(&decoded_data[batch_idx].rot[0][0][j]);
      __m256 rot0_x_hi = _mm256_load_ps(&decoded_data[batch_idx].rot[0][0][j + 8]);
      __m256 rot0_y_lo = _mm256_load_ps(&decoded_data[batch_idx].rot[0][1][j]);
      __m256 rot0_y_hi = _mm256_load_ps(&decoded_data[batch_idx].rot[0][1][j + 8]);
      __m256 rot0_z_lo = _mm256_load_ps(&decoded_data[batch_idx].rot[0][2][j]);
      __m256 rot0_z_hi = _mm256_load_ps(&decoded_data[batch_idx].rot[0][2][j + 8]);
      __m256 rot1_x_lo = _mm256_load_ps(&decoded_data[batch_idx].rot[1][0][j]);
      __m256 rot1_x_hi = _mm256_load_ps(&decoded_data[batch_idx].rot[1][0][j + 8]);
      __m256 rot1_y_lo = _mm256_load_ps(&decoded_data[batch_idx].rot[1][1][j]);
      __m256 rot1_y_hi = _mm256_load_ps(&decoded_data[batch_idx].rot[1][1][j + 8]);
      __m256 rot1_z_lo = _mm256_load_ps(&decoded_data[batch_idx].rot[1][2][j]);
      __m256 rot1_z_hi = _mm256_load_ps(&decoded_data[batch_idx].rot[1][2][j + 8]);
      __m256 tan0_rx_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[0][0][j]);
      __m256 tan0_rx_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[0][0][j + 8]);
      __m256 tan0_ry_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[0][1][j]);
      __m256 tan0_ry_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[0][1][j + 8]);
      __m256 tan0_rz_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[0][2][j]);
      __m256 tan0_rz_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[0][2][j + 8]);
      __m256 tan1_rx_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[1][0][j]);
      __m256 tan1_rx_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[1][0][j + 8]);
      __m256 tan1_ry_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[1][1][j]);
      __m256 tan1_ry_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[1][1][j + 8]);
      __m256 tan1_rz_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[1][2][j]);
      __m256 tan1_rz_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_rot[1][2][j + 8]);
      __m256 scl0_lo = _mm256_load_ps(&decoded_data[batch_idx].scl[0][j]);
      __m256 scl0_hi = _mm256_load_ps(&decoded_data[batch_idx].scl[0][j + 8]);
      __m256 scl1_lo = _mm256_load_ps(&decoded_data[batch_idx].scl[1][j]);
      __m256 scl1_hi = _mm256_load_ps(&decoded_data[batch_idx].scl[1][j + 8]);
      __m256 tan0_s_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_scl[0][j]);
      __m256 tan0_s_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_scl[0][j + 8]);
      __m256 tan1_s_lo = _mm256_load_ps(&decoded_data[batch_idx].tan_scl[1][j]);
      __m256 tan1_s_hi = _mm256_load_ps(&decoded_data[batch_idx].tan_scl[1][j + 8]);
      // Hermite interpolation: result = h00*p0 + h01*p1 + h10*t0 + h11*t1
      __m256 result_x_lo = _mm256_fmadd_ps(h00, pos0_x_lo, _mm256_fmadd_ps(h01, pos1_x_lo, _mm256_fmadd_ps(h10, tan0_x_lo, _mm256_mul_ps(h11, tan1_x_lo))));
      __m256 result_x_hi = _mm256_fmadd_ps(h00, pos0_x_hi, _mm256_fmadd_ps(h01, pos1_x_hi, _mm256_fmadd_ps(h10, tan0_x_hi, _mm256_mul_ps(h11, tan1_x_hi))));
      __m256 result_y_lo = _mm256_fmadd_ps(h00, pos0_y_lo, _mm256_fmadd_ps(h01, pos1_y_lo, _mm256_fmadd_ps(h10, tan0_y_lo, _mm256_mul_ps(h11, tan1_y_lo))));
      __m256 result_y_hi = _mm256_fmadd_ps(h00, pos0_y_hi, _mm256_fmadd_ps(h01, pos1_y_hi, _mm256_fmadd_ps(h10, tan0_y_hi, _mm256_mul_ps(h11, tan1_y_hi))));
      __m256 result_z_lo = _mm256_fmadd_ps(h00, pos0_z_lo, _mm256_fmadd_ps(h01, pos1_z_lo, _mm256_fmadd_ps(h10, tan0_z_lo, _mm256_mul_ps(h11, tan1_z_lo))));
      __m256 result_z_hi = _mm256_fmadd_ps(h00, pos0_z_hi, _mm256_fmadd_ps(h01, pos1_z_hi, _mm256_fmadd_ps(h10, tan0_z_hi, _mm256_mul_ps(h11, tan1_z_hi))));
      __m256 result_scl_lo = _mm256_fmadd_ps(h00, scl0_lo, _mm256_fmadd_ps(h01, scl1_lo, _mm256_fmadd_ps(h10, tan0_s_lo, _mm256_mul_ps(h11, tan1_s_lo))));
      __m256 result_scl_hi = _mm256_fmadd_ps(h00, scl0_hi, _mm256_fmadd_ps(h01, scl1_hi, _mm256_fmadd_ps(h10, tan0_s_hi, _mm256_mul_ps(h11, tan1_s_hi))));
      __m256 result_rot_x_lo = _mm256_fmadd_ps(h00, rot0_x_lo, _mm256_fmadd_ps(h01, rot1_x_lo, _mm256_fmadd_ps(h10, tan0_rx_lo, _mm256_mul_ps(h11, tan1_rx_lo))));
      __m256 result_rot_x_hi = _mm256_fmadd_ps(h00, rot0_x_hi, _mm256_fmadd_ps(h01, rot1_x_hi, _mm256_fmadd_ps(h10, tan0_rx_hi, _mm256_mul_ps(h11, tan1_rx_hi))));
      __m256 result_rot_y_lo = _mm256_fmadd_ps(h00, rot0_y_lo, _mm256_fmadd_ps(h01, rot1_y_lo, _mm256_fmadd_ps(h10, tan0_ry_lo, _mm256_mul_ps(h11, tan1_ry_lo))));
      __m256 result_rot_y_hi = _mm256_fmadd_ps(h00, rot0_y_hi, _mm256_fmadd_ps(h01, rot1_y_hi, _mm256_fmadd_ps(h10, tan0_ry_hi, _mm256_mul_ps(h11, tan1_ry_hi))));
      __m256 result_rot_z_lo = _mm256_fmadd_ps(h00, rot0_z_lo, _mm256_fmadd_ps(h01, rot1_z_lo, _mm256_fmadd_ps(h10, tan0_rz_lo, _mm256_mul_ps(h11, tan1_rz_lo))));
      __m256 result_rot_z_hi = _mm256_fmadd_ps(h00, rot0_z_hi, _mm256_fmadd_ps(h01, rot1_z_hi, _mm256_fmadd_ps(h10, tan0_rz_hi, _mm256_mul_ps(h11, tan1_rz_hi))));
      // Stream to output (SoA layout per model: 7 blocks of 256 floats: x, y, z, scl, rot_x, rot_y, rot_z)
      _mm256_stream_ps(out_ptr + j + 0 * 256, result_x_lo);
      _mm256_stream_ps(out_ptr + j + 8 + 0 * 256, result_x_hi);
      _mm256_stream_ps(out_ptr + j + 1 * 256, result_y_lo);
      _mm256_stream_ps(out_ptr + j + 8 + 1 * 256, result_y_hi);
      _mm256_stream_ps(out_ptr + j + 2 * 256, result_z_lo);
      _mm256_stream_ps(out_ptr + j + 8 + 2 * 256, result_z_hi);
      _mm256_stream_ps(out_ptr + j + 3 * 256, result_scl_lo);
      _mm256_stream_ps(out_ptr + j + 8 + 3 * 256, result_scl_hi);
      _mm256_stream_ps(out_ptr + j + 4 * 256, result_rot_x_lo);
      _mm256_stream_ps(out_ptr + j + 8 + 4 * 256, result_rot_x_hi);
      _mm256_stream_ps(out_ptr + j + 5 * 256, result_rot_y_lo);
      _mm256_stream_ps(out_ptr + j + 8 + 5 * 256, result_rot_y_hi);
      _mm256_stream_ps(out_ptr + j + 6 * 256, result_rot_z_lo);
      _mm256_stream_ps(out_ptr + j + 8 + 6 * 256, result_rot_z_hi);
    }
    out_ptr += 256 * ANIM_TRK_ELM_CNT; // Advance to next mdl's output
  }
}
static void
anim_update(float* out, int mdl_start, int mdl_cnt) {
  align32 struct anm_frame decoded_data[BATCH_SIZE]; // 112 KB (28 KB * 4); move to heap
  int remainder = (mdl_cnt % BATCH_SIZE);
  int batch_cnt = mdl_cnt / BATCH_SIZE + !!remainder;
  for (int b = 0; b < batch_cnt; b++) {
    align32 float t_scalar[BATCH_SIZE] = {0}; // 16 B
    align32 float h_coeffs[BATCH_SIZE * 4] = {0}; // 64 B: h00, h01, h10, h11
    int start = mdl_start + b * BATCH_SIZE;
    int cnt = (start + BATCH_SIZE <= mdl_start + mdl_cnt) ? BATCH_SIZE : remainder;
    anim_decode_frames(decoded_data, start, cnt);
    compute_hermite_coeffs(t_scalar, h_coeffs, start, cnt);
    // `out` already points at this chunk's first model, so offset relative to mdl_start
    anim_interpolate_frames(decoded_data, out + (start - mdl_start) * 256 * ANIM_TRK_ELM_CNT, start, cnt, h_coeffs);
  }
  _mm_sfence(); // Ensure all streaming stores complete
}
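
// Single-threaded usage sketch (assumes `out` is 32-byte aligned and holds cnt * 256 * ANIM_TRK_ELM_CNT floats):
//   anim_update(out, 0, cnt);
// The thread pool below fans the same call out across workers, one chunk of models each.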
/* ---------------------------------------------------------------------------
 * Thread Pool
 * ---------------------------------------------------------------------------
 */
#define MAX_THREADS 128
#define QUEUE_SIZE 1024
#define BATCH_SIZE 4
// Task structure
struct thrd_pool_job {
  float* out; // Output buffer for this chunk
  int start; // Starting model index
  int cnt; // Number of models to process
  int valid; // Is this tsk active?
};
struct thrd_pool { | |
pthread_t threads[MAX_THREADS]; | |
struct thrd_pool_job queue[QUEUE_SIZE]; | |
int thread_cnt; | |
int tsk_cnt; | |
int tsks_done; | |
int shutdown; | |
pthread_mutex_t mutex; | |
pthread_cond_t cond_not_empty; | |
pthread_cond_t cond_done; | |
}; | |
static struct thrd_pool pool;
// Worker thread function
static void*
thrd_wrkr(void* arg) {
  struct thrd_pool* pool = (struct thrd_pool*)arg;
  while (1) {
    struct thrd_pool_job tsk;
    int got_tsk = 0;
    pthread_mutex_lock(&pool->mutex);
    while (!pool->shutdown && pool->tsk_cnt == 0) {
      pthread_cond_wait(&pool->cond_not_empty, &pool->mutex);
    }
    if (pool->shutdown && pool->tsk_cnt == 0) {
      pthread_mutex_unlock(&pool->mutex);
      break;
    }
    if (pool->tsk_cnt > 0) {
      tsk = pool->queue[0];
      for (int i = 1; i < pool->tsk_cnt; i++) {
        pool->queue[i - 1] = pool->queue[i];
      }
      pool->tsk_cnt--;
      pool->tsks_in_flight++;
      got_tsk = 1;
    }
    pthread_mutex_unlock(&pool->mutex);
    if (got_tsk) {
      anim_update(tsk.out, tsk.start, tsk.cnt);
      pthread_mutex_lock(&pool->mutex);
      pool->tsks_in_flight--;
      pool->tsks_done++;
      if (pool->tsk_cnt == 0 && pool->tsks_in_flight == 0) { // Queue drained and nothing still running
        pthread_cond_broadcast(&pool->cond_done);
      }
      pthread_mutex_unlock(&pool->mutex);
    }
  }
  return 0;
}
static void
thrd_pool_init(int num_threads) {
  num_threads = (num_threads > MAX_THREADS) ? MAX_THREADS : num_threads;
  pthread_mutex_init(&pool.mutex, 0);
  pthread_cond_init(&pool.cond_not_empty, 0);
  pthread_cond_init(&pool.cond_done, 0);
  pool.thread_cnt = num_threads;
  pool.tsk_cnt = 0;
  pool.tsks_in_flight = 0;
  pool.tsks_done = 0;
  pool.shutdown = 0;
  for (int i = 0; i < num_threads; i++) {
    pthread_create(&pool.threads[i], 0, thrd_wrkr, &pool);
  }
}
static void
thrd_pool_submit(float* out, int start, int cnt) {
  pthread_mutex_lock(&pool.mutex);
  if (pool.tsk_cnt < QUEUE_SIZE) {
    struct thrd_pool_job tsk = {out, start, cnt, 1};
    pool.queue[pool.tsk_cnt++] = tsk;
    pthread_cond_signal(&pool.cond_not_empty);
  } else {
    fprintf(stderr, "tsk queue full!\n");
  }
  pthread_mutex_unlock(&pool.mutex);
}
static void
thrd_pool_wait(void) {
  pthread_mutex_lock(&pool.mutex);
  // Wait until the queue is drained and no worker is still running a task
  while (pool.tsk_cnt > 0 || pool.tsks_in_flight > 0) {
    pthread_cond_wait(&pool.cond_done, &pool.mutex);
  }
  pool.tsks_done = 0;
  pthread_mutex_unlock(&pool.mutex);
}
static void
thrd_pool_shutdown(void) {
  pthread_mutex_lock(&pool.mutex);
  pool.shutdown = 1;
  pthread_cond_broadcast(&pool.cond_not_empty);
  pthread_mutex_unlock(&pool.mutex);
  for (int i = 0; i < pool.thread_cnt; i++) {
    pthread_join(pool.threads[i], 0);
  }
  pthread_mutex_destroy(&pool.mutex);
  pthread_cond_destroy(&pool.cond_not_empty);
  pthread_cond_destroy(&pool.cond_done);
}
// Multi-threaded anim_update wrapper with batch alignment
static void
job_anim_update(float* out, int total_models) {
  int cores = pool.thread_cnt;
  int base_models_per_thread = total_models / cores; // Integer division
  int remainder = total_models % cores; // Leftover models
  int processed_models = 0;
  for (int i = 0; i < cores; i++) {
    int start = processed_models;
    int cnt = base_models_per_thread + (i < remainder ? 1 : 0); // Distribute remainder
    if (cnt > 0) {
      float* chunk_out = out + start * 256 * 7;
      // Adjust cnt to be a multiple of BATCH_SIZE, handle remainder later if needed
      int aligned_cnt = (cnt / BATCH_SIZE) * BATCH_SIZE;
      if (aligned_cnt > 0) {
        thrd_pool_submit(chunk_out, start, aligned_cnt);
      }
      processed_models += aligned_cnt;
    }
  }
  // Handle remaining models (not aligned to BATCH_SIZE)
  int leftover = total_models - processed_models;
  if (leftover > 0) {
    float* chunk_out = out + processed_models * 256 * 7;
    thrd_pool_submit(chunk_out, processed_models, leftover);
  }
  thrd_pool_wait();
}
extern int
main(void) {
  thrd_pool_init(8); // Example: 8 cores
  int total_models = 77601; // Not a multiple of 4 or 8
  // _mm256_stream_ps needs 32-byte-aligned destinations, so allocate the output buffer aligned
  float* out = aligned_alloc(32, (size_t)total_models * 256 * 7 * sizeof(float));
  job_anim_update(out, total_models);
  free(out);
  thrd_pool_shutdown();
  return 0;
}
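
// Build sketch (not part of the original gist): the intrinsics above require AVX2 + FMA, e.g.
//   cc -O2 -mavx2 -mfma -pthread <sources>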