Last active
April 14, 2025 12:56
-
-
Save komakai/36534387490ba9f7f7ce595cc0e2c0a1 to your computer and use it in GitHub Desktop.
Card Perms
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Red-black tree augmented with subtree sizes, so that the rank of a key among
// the stored keys can be computed in O(log N) while descending for an insert.
// Duplicate keys are allowed; a key equal to a stored key descends to the right.
template <typename T>
class RedBlackTree {
public:
    RedBlackTree() {
        // One shared black sentinel stands in for every leaf and for the root's
        // parent. Its subtree_size stays 0, which the size arithmetic relies on.
        root = NIL = new Node();
        NIL->parent = NIL->left = NIL->right = NIL;
    }

    ~RedBlackTree() {
        destroyTree(root);
        delete NIL;
    }

    // The tree owns its nodes through raw pointers, so a member-wise copy would
    // free every node twice. Copying is therefore disabled.
    RedBlackTree(const RedBlackTree&) = delete;
    RedBlackTree& operator=(const RedBlackTree&) = delete;

    // Inserts `data` and returns how many keys already in the tree are not
    // greater than it (i.e. keys k for which `data < k` is false).
    //
    // When `dryRun` is true the rank is computed but the tree is left completely
    // untouched: no node is added and no subtree size changes.
    int insert(const T& data, bool dryRun = false) {
        Node* parent = NIL;
        Node* cur = root;
        int order = 0;
        while (cur != NIL) {
            parent = cur;
            if (!dryRun) {
                // Every node on the search path gains the new node as a descendant.
                // This must not happen on a dry run, or later ranks are inflated.
                cur->subtree_size++;
            }
            const bool goLeft = data < cur->data;
            if (!goLeft) {
                // cur and its whole left subtree are not greater than data.
                // NIL has size 0, so an absent left child contributes just 1.
                order += cur->left->subtree_size + 1;
            }
            cur = cur->getSubNode(goLeft);
        }
        if (!dryRun) {
            Node* z = new Node(data);
            z->left = z->right = NIL;
            z->parent = parent;
            if (parent == NIL) {
                root = z;
            } else {
                parent->getSubNode(z->data < parent->data) = z;
            }
            insertFixup(z);
        }
        return order;
    }

private:
    enum class Color { RED, BLACK };

    struct Node {
        T data;
        Color color;
        Node* parent;
        Node* left;
        Node* right;
        int subtree_size;  // number of real nodes in the subtree rooted here

        // Returns a reference to the left or right child slot.
        Node*& getSubNode(bool isLeft) { return isLeft ? left : right; }

        // Sentinel constructor: black, empty subtree.
        Node() : data(T()), color(Color::BLACK), parent(nullptr), left(nullptr), right(nullptr), subtree_size(0) {}
        // Regular constructor: new nodes start red with a subtree of one.
        Node(const T& data) : data(data), color(Color::RED), parent(nullptr), left(nullptr), right(nullptr), subtree_size(1) {}
    };

    Node* root;
    Node* NIL;

    // Rotates around x. With isLeft == true this is a left rotation (x's right
    // child moves up); with isLeft == false it is the mirrored right rotation.
    void rotate(Node* x, bool isLeft = true) {
        Node* y = x->getSubNode(!isLeft);
        x->getSubNode(!isLeft) = y->getSubNode(isLeft);
        if (y->getSubNode(isLeft) != NIL) {
            y->getSubNode(isLeft)->parent = x;
        }
        y->parent = x->parent;
        if (x->parent == NIL) {
            root = y;
        } else {
            // Put y into whichever slot of the parent used to hold x.
            x->parent->getSubNode(x == x->parent->getSubNode(true)) = y;
        }
        y->getSubNode(isLeft) = x;
        x->parent = y;
        // x is now below y, so its size must be recomputed first.
        updateSubtreeSize(x);
        updateSubtreeSize(y);
    }

    // Restores the red-black colouring rules after z has been linked in as a red node.
    void insertFixup(Node* z) {
        while (z->parent->color == Color::RED) {
            const bool isLeft = (z->parent == z->parent->parent->left);
            Node* uncle = z->parent->parent->getSubNode(!isLeft);
            if (uncle->color == Color::RED) {
                // Red uncle: recolour and continue from the grandparent.
                z->parent->color = Color::BLACK;
                uncle->color = Color::BLACK;
                z->parent->parent->color = Color::RED;
                z = z->parent->parent;
            } else {
                if (z == z->parent->getSubNode(!isLeft)) {
                    // Inner grandchild: rotate it to the outside first.
                    z = z->parent;
                    rotate(z, isLeft);
                }
                z->parent->color = Color::BLACK;
                z->parent->parent->color = Color::RED;
                rotate(z->parent->parent, !isLeft);
            }
        }
        root->color = Color::BLACK;
    }

    // Recomputes node's size from its children; the sentinel is never modified.
    void updateSubtreeSize(Node* node) {
        if (node == NIL) return;
        node->subtree_size = 1 + node->left->subtree_size + node->right->subtree_size;
    }

    // Frees every real node below (and including) node.
    void destroyTree(Node* node) {
        if (node != NIL) {
            destroyTree(node->left);
            destroyTree(node->right);
            delete node;
        }
    }
};
// Integer arithmetic modulo a compile-time constant N.
// The stored residue is always kept in the range [0, N).
//
// All intermediate arithmetic is done in `unsigned long long` (at least 64 bits)
// rather than `unsigned long`, which is only 32 bits on LLP64 platforms and would
// overflow when two residues close to N are multiplied.
template <int N> class intModuloN {
    static_assert(N > 0, "intModuloN requires a positive modulus");
public:
    // Intentionally implicit so that plain integers convert to residues.
    intModuloN(unsigned long long i)
        : n(static_cast<int>(i % static_cast<unsigned long long>(N))) {}

    intModuloN<N>& operator+=(const intModuloN<N>& rhs) {
        n = static_cast<int>((static_cast<unsigned long long>(n) + static_cast<unsigned long long>(rhs.n))
                             % static_cast<unsigned long long>(N));
        return *this;
    }

    // Pre-increment, wrapping from N-1 back to 0.
    intModuloN<N>& operator++() {
        n = static_cast<int>((static_cast<unsigned long long>(n) + 1ULL)
                             % static_cast<unsigned long long>(N));
        return *this;
    }

    // Pre-decrement, wrapping from 0 to N-1 (N is added first to stay non-negative).
    intModuloN<N>& operator--() {
        n = static_cast<int>((static_cast<unsigned long long>(n) + static_cast<unsigned long long>(N) - 1ULL)
                             % static_cast<unsigned long long>(N));
        return *this;
    }

    // Yields the residue as a plain int in [0, N).
    operator int() const { return n; }

private:
    int n;
};

// Sum of two residues, reduced modulo N.
template <int N>
intModuloN<N> operator+(const intModuloN<N>& lhs, const intModuloN<N>& rhs) {
    return intModuloN<N>(static_cast<unsigned long long>(static_cast<int>(lhs))
                         + static_cast<unsigned long long>(static_cast<int>(rhs)));
}

// Product of two residues, reduced modulo N. The multiplication is carried out
// in 64 bits so that it cannot overflow for any int modulus.
template <int N>
intModuloN<N> operator*(const intModuloN<N>& lhs, const intModuloN<N>& rhs) {
    return intModuloN<N>(static_cast<unsigned long long>(static_cast<int>(lhs))
                         * static_cast<unsigned long long>(static_cast<int>(rhs)));
}
const int D = 1000000007; | |
typedef intModuloN<D> modInt; | |
// Evaluates, modulo D, a rank-style sum for the sequence `x`.
//
// `x` holds values in 1..n (n = x.size()) where 0 marks a position whose value
// is unknown; the values of 1..n that never appear are the "missing" values.
// The result starts at zeros! and, scanning the sequence from its last element
// to its first, adds a term weighted by i! for every position i.
// NOTE(review): this looks like the sum of 1-based lexicographic ranks over all
// ways of filling the zeros with the missing values - confirm against the
// problem statement.
//
// Preconditions taken from how the code indexes its tables: every non-zero
// entry is in [1, n], and the number of zeros equals the number of missing values.
// An empty input returns 1 (the initial value 0!) instead of reading out of bounds.
long solve(vector<int> x) {
    const int n = (int)x.size();

    // permsCache[k] = k! and permsDiv2Cache[k] = k!/2, both modulo D (k!/2 is an
    // integer for k >= 2, so it is built by multiplication, not by division).
    vector<modInt> permsCache { 1, 1, 2 }, permsDiv2Cache { 0, 0, 1 };
    for (int i = 3; i <= n; i++) {
        permsCache.push_back(modInt(i) * permsCache[i-1]);
        permsDiv2Cache.push_back(modInt(i) * permsDiv2Cache[i-1]);
    }

    // The sequence is processed back to front. `x` is already a private copy,
    // so it is reversed in place rather than copied again.
    reverse(x.begin(), x.end());
    const vector<int>& xRev = x;

    int zeros = 0;                  // total number of unknown (0) entries
    vector<bool> present(n, false); // present[v-1] is true when value v occurs in x
    for (int v : xRev) {
        if (v == 0) {
            zeros++;
        } else {
            present[v - 1] = true;
        }
    }

    // Values with unknown position, in increasing order.
    vector<int> missingValues;
    for (int v = 1; v <= n; v++) {
        if (!present[v - 1]) {
            missingValues.push_back(v);
        }
    }

    // gaps holds one entry per maximal run of consecutive missing values:
    // (first value of the run, number of missing values >= that value).
    // A sentinel (n + 1, 0) guarantees every upper_bound lookup below succeeds.
    auto cmp = [](pair<int, int> a, pair<int, int> b) { return a.first < b.first; };
    set<pair<int, int>, decltype(cmp)> gaps(cmp);
    int remaining = (int)missingValues.size();
    auto missIt = missingValues.begin();
    while (missIt != missingValues.end()) {
        const int runStart = *missIt;
        const int weight = remaining;
        int expected = runStart;
        while (missIt != missingValues.end() && *missIt == expected) {
            --remaining;
            ++missIt;
            ++expected;
        }
        gaps.insert(make_pair(runStart, weight));
    }
    gaps.insert(make_pair(n + 1, 0));

    // Number of missing values strictly greater than v (all of them for v == 0).
    auto missingAbove = [&gaps](int v) {
        return (*gaps.upper_bound(make_pair(v, 0))).second;
    };

    RedBlackTree<int> setR; // known values seen so far in the back-to-front scan
    modInt sumRs = 0;       // sum of missingAbove(v) over the known values seen so far
    if (n > 0 && xRev[0] != 0) {
        setR.insert(xRev[0]);
        sumRs += modInt(missingAbove(xRev[0]));
    }

    modInt d = 0;           // number of zeros among the entries already scanned
    modInt answer = permsCache[zeros];
    for (int i = 1; i < n; i++) {
        if (xRev[i-1] == 0) ++d;
        const bool known = (xRev[i] != 0);
        if (known || zeros == 1) {
            const int above = missingAbove(xRev[i]);
            if (known) sumRs += modInt(above);
            // With exactly one zero its value is forced to the single missing
            // value, so its rank is only queried (dry run) and nothing is inserted.
            modInt r = setR.insert(known ? xRev[i] : missingValues.front(), !known);
            // K = number of missing values not counted in `above`.
            modInt K = missingValues.size() - above;
            answer += ((zeros == 0) ? r : (((d * K) + (r * modInt(zeros))) * permsCache[zeros - 1])) * permsCache[i];
        } else {
            // Unknown entry with at least two zeros overall.
            answer += ((d * permsDiv2Cache[zeros]) + (sumRs * permsCache[zeros - 1])) * permsCache[i];
        }
    }
    return answer;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment