Skip to content

Instantly share code, notes, and snippets.

@komakai
Last active April 14, 2025 12:56
Show Gist options
  • Save komakai/36534387490ba9f7f7ce595cc0e2c0a1 to your computer and use it in GitHub Desktop.
Save komakai/36534387490ba9f7f7ce595cc0e2c0a1 to your computer and use it in GitHub Desktop.
Card Perms
// Red-black tree with subtree sizes for O(log(N)) calculation of order within the set.
// Equal keys are allowed; a key equal to an existing one descends into that node's right subtree.
template <typename T>
class RedBlackTree {
public:
    RedBlackTree() {
        // One shared black sentinel stands in for every empty child and for the root's parent.
        root = NIL = new Node();
        NIL->parent = NIL->left = NIL->right = NIL;
    }
    ~RedBlackTree() {
        destroyTree(root);
        delete NIL;
    }
    // The tree owns its nodes through raw pointers; a member-wise copy would delete them twice.
    RedBlackTree(const RedBlackTree&) = delete;
    RedBlackTree& operator=(const RedBlackTree&) = delete;

    // Returns how many stored elements e satisfy !(data < e), i.e. are <= data.
    // Inserts data unless dryRun is true. A dry run is a pure query: it must not touch
    // subtree sizes, otherwise every later order computed through those nodes is overcounted.
    int insert(const T& data, bool dryRun = false) {
        Node *y = NIL, *x = root;
        int order = 0;
        while (x != NIL) {
            y = x;
            const bool goLeft = data < x->data;
            if (!dryRun) {
                x->subtree_size++;  // the new node will become a descendant of x
            }
            if (!goLeft) {
                // x and its whole left subtree are <= data (the sentinel's subtree_size is 0).
                order += x->left->subtree_size + 1;
            }
            x = x->getSubNode(goLeft);
        }
        if (!dryRun) {
            Node* z = new Node(data);
            z->left = z->right = NIL;
            z->parent = y;
            if (y == NIL) {
                root = z;
            } else {
                y->getSubNode(z->data < y->data) = z;
            }
            insertFixup(z);
        }
        return order;
    }
private:
    enum class Color { RED, BLACK };
    struct Node {
        T data;
        Color color;
        Node* parent;
        Node* left;
        Node* right;
        int subtree_size;  // number of real nodes in the subtree rooted here (0 for the sentinel)
        // Reference to the left child when isLeft is true, otherwise to the right child.
        Node*& getSubNode(bool isLeft) { return isLeft ? left : right; }
        // Sentinel constructor: black, empty subtree.
        Node() : data(T()), color(Color::BLACK), parent(nullptr), left(nullptr), right(nullptr), subtree_size(0) {}
        // Fresh node constructor: red leaf counting only itself.
        Node(const T& data) : data(data), color(Color::RED), parent(nullptr), left(nullptr), right(nullptr), subtree_size(1) {}
    };
    Node* root;
    Node* NIL;
    // Rotates around x: isLeft == true performs a left rotation (x's right child moves up),
    // false a right rotation. Subtree sizes of the two moved nodes are recomputed bottom-up.
    void rotate(Node* x, bool isLeft = true) {
        Node* y = x->getSubNode(!isLeft);
        x->getSubNode(!isLeft) = y->getSubNode(isLeft);
        if (y->getSubNode(isLeft) != NIL) {
            y->getSubNode(isLeft)->parent = x;
        }
        y->parent = x->parent;
        if (x->parent == NIL) {
            root = y;
        } else {
            x->parent->getSubNode(x == x->parent->getSubNode(true)) = y;
        }
        y->getSubNode(isLeft) = x;
        x->parent = y;
        updateSubtreeSize(x);  // x is now y's child, so it must be recomputed first
        updateSubtreeSize(y);
    }
    // Restores the red-black properties after inserting the red node z.
    void insertFixup(Node* z) {
        while (z->parent->color == Color::RED) {
            bool isLeft = (z->parent == z->parent->parent->left);
            Node* y = z->parent->parent->getSubNode(!isLeft);  // uncle of z
            if (y->color == Color::RED) {
                // Red uncle: recolour and continue from the grandparent.
                z->parent->color = Color::BLACK;
                y->color = Color::BLACK;
                z->parent->parent->color = Color::RED;
                z = z->parent->parent;
            } else {
                if (z == z->parent->getSubNode(!isLeft)) {
                    // Inner child: rotate it to the outside first.
                    z = z->parent;
                    rotate(z, isLeft);
                }
                z->parent->color = Color::BLACK;
                z->parent->parent->color = Color::RED;
                rotate(z->parent->parent, !isLeft);
            }
        }
        root->color = Color::BLACK;
    }
    // Recomputes node's size from its children; the sentinel is never modified.
    void updateSubtreeSize(Node* node) {
        if (node == NIL) return;
        node->subtree_size = 1 + node->left->subtree_size + node->right->subtree_size;
    }
    // Post-order deletion of every real node below (and including) node.
    void destroyTree(Node* node) {
        if (node != NIL) {
            destroyTree(node->left);
            destroyTree(node->right);
            delete node;
        }
    }
};
// template for arithmetic modulo N; the stored value is always in [0, N).
template <int N> class intModuloN {
    static_assert(N > 0, "intModuloN requires a positive modulus");
public:
    // Deliberately implicit so integer literals and counts convert to residues directly.
    // Takes unsigned long long so 64-bit inputs (e.g. sizes) are reduced rather than
    // truncated on platforms where unsigned long is only 32 bits.
    intModuloN(unsigned long long i) : n(static_cast<int>(i % static_cast<unsigned long long>(N))) {}
    intModuloN<N>& operator+=(const intModuloN<N>& rhs) {
        // Both operands are < N <= INT_MAX, so the sum cannot overflow 64 bits.
        n = static_cast<int>((static_cast<unsigned long long>(n) + static_cast<unsigned long long>(rhs.n))
                             % static_cast<unsigned long long>(N));
        return *this;
    }
    intModuloN<N>& operator++() {
        n = (n == N - 1) ? 0 : n + 1;
        return *this;
    }
    intModuloN<N>& operator--() {
        // Explicit wrap: n + N - 1 would overflow a 32-bit long for moduli above about 2^30.
        n = (n == 0) ? N - 1 : n - 1;
        return *this;
    }
    operator int() const { return n; }
private:
    int n;
};
// Sum of two residues modulo N, expressed through the in-place addition operator.
template <int N>
intModuloN<N> operator+(const intModuloN<N>& lhs, const intModuloN<N>& rhs) {
    intModuloN<N> sum(lhs);
    sum += rhs;
    return sum;
}
// Product of two residues modulo N.
// The product is formed in unsigned long long (at least 64 bits): unsigned long is only
// 32 bits on LLP64 platforms such as 64-bit Windows, where (N-1)^2 for N = 1000000007
// would silently wrap. The value is reduced before construction, so it always fits.
template <int N>
intModuloN<N> operator*(const intModuloN<N>& lhs, const intModuloN<N>& rhs) {
    const unsigned long long product =
        static_cast<unsigned long long>(static_cast<int>(lhs)) *
        static_cast<unsigned long long>(static_cast<int>(rhs));
    return intModuloN<N>(static_cast<unsigned long>(product % static_cast<unsigned long long>(N)));
}
// Modulus applied to every answer (the prime 10^9 + 7) and the residue type built on it.
constexpr int D = 1000000007;
using modInt = intModuloN<D>;
// Returns, modulo D, the sum of the 1-based lexicographic indices of every permutation of
// 1..n consistent with x, where x[k] == 0 marks an unknown entry and any other value is fixed.
// Assumes every non-zero entry is a distinct value in [1, n] (not validated here).
// An empty x yields 1: the loop is skipped and only the initial 0! term remains.
long solve(vector<int> x) {
    const int n = (int)x.size();
    // Pre-calculate factorial(k) and factorial(k)/2 modulo D for k = 0..n.
    // permsDiv2Cache[0] and [1] are placeholders: they are only read when at least two entries are unknown.
    vector<modInt> permsCache { 1, 1, 2 }, permsDiv2Cache { 0, 0, 1 };
    for (int i = 3; i <= n; i++) {
        permsCache.push_back(modInt(i) * permsCache[i - 1]);
        permsDiv2Cache.push_back(modInt(i) * permsDiv2Cache[i - 1]);
    }
    // Work right-to-left so that, at step i, exactly i entries lie to the right of the current one.
    vector<int> xRev(x.rbegin(), x.rend());
    int zeros = 0;                 // total number of unknown entries
    vector<char> present(n, 0);   // present[v - 1] != 0 iff value v is fixed somewhere in x
    for (int v : xRev) {
        if (v == 0) {
            zeros++;
        } else {
            present[v - 1] = 1;
        }
    }
    // missingGreater[v] = number of values in (v, n] with unknown position, for v in [0, n].
    vector<int> missingGreater(n + 1, 0);
    int firstMissing = 0;          // smallest value with unknown position (0 if there is none)
    for (int v = n; v >= 1; --v) {
        missingGreater[v - 1] = missingGreater[v] + (present[v - 1] ? 0 : 1);
        if (!present[v - 1]) firstMissing = v;
    }
    const int missingTotal = missingGreater[0];

    RedBlackTree<int> setR;  // fixed values already passed (i.e. to the right of the current entry)
    modInt sumRs = 0;        // over those fixed values: sum of the number of missing values greater than each
    modInt d = 0;            // number of unknown entries to the right of the current entry
    modInt answer = permsCache[zeros];  // the "+1" of each of the zeros! consistent permutations
    for (int i = 0; i < n; i++) {
        const int card = xRev[i];
        if (card != 0 || zeros == 1) {
            // The value here is known: either fixed, or the only unknown entry (its value is forced).
            if (card != 0) sumRs += modInt(missingGreater[card]);
            // r = fixed values to the right that are <= this value. The forced value is only probed
            // (dry run), never inserted: later entries account for it through d * K instead.
            modInt r = setR.insert(card != 0 ? card : firstMissing, card == 0);
            modInt K = missingTotal - missingGreater[card];  // missing values smaller than card (0 when card == 0)
            // Over all zeros! fillings: r smaller fixed values each time, plus each of the d unknown
            // positions to the right holding each of the K smaller missing values in (zeros-1)! fillings.
            answer += ((zeros == 0) ? r : (((d * K) + (r * modInt(zeros))) * permsCache[zeros - 1])) * permsCache[i];
        } else {
            // Unknown entry with at least two unknowns overall: it exceeds each unknown to its right
            // in half of all fillings, and exceeds a fixed value c in missingGreater[c] * (zeros-1)! fillings.
            answer += ((d * permsDiv2Cache[zeros]) + (sumRs * permsCache[zeros - 1])) * permsCache[i];
        }
        if (card == 0) ++d;
    }
    return answer;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment