Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save SuryaPratapK/edbdb2b52de8809ae0e2d36245ad0f0f to your computer and use it in GitHub Desktop.
Save SuryaPratapK/edbdb2b52de8809ae0e2d36245ad0f0f to your computer and use it in GitHub Desktop.
// Counts "good triplets": value triples (x, y, z) that appear in the same
// relative order in both nums1 and nums2 (both are permutations of the same
// distinct values).
//
// For every element y of nums1 taken as the middle of a triplet:
//   common_left  = #elements before y in nums1 that are also before y in nums2
//   common_right = #elements after  y in nums1 that are also after  y in nums2
// and y contributes common_left * common_right triplets. A segment tree over
// nums2 positions answers "how many already-seen nums1 elements sit at a nums2
// position <= p" in O(log n), giving O(n log n) overall.
class Solution {
    using ll = long long;

    // 1-indexed segment tree over nums2 positions [0, n-1]. Leaf p holds 1 once
    // the value at nums2[p] has been visited in nums1; internal nodes hold the
    // sum of their two children.
    std::vector<ll> seg_tree;

    // Point update: add 1 at position query_idx.
    // st_idx is the current node, which covers positions [start, end].
    void updateSegTree(ll st_idx, ll start, ll end, ll query_idx) {
        if (end < query_idx || start > query_idx) // Case-1: No Overlap
            return;
        if (start == end) { // Case-2: Total Overlap (the target leaf)
            seg_tree[st_idx]++;
            return;
        }
        // Case-3: Partial Overlap -> recurse into both halves, then recombine.
        const ll mid = start + (end - start) / 2;
        updateSegTree(2 * st_idx, start, mid, query_idx);
        updateSegTree(2 * st_idx + 1, mid + 1, end, query_idx);
        seg_tree[st_idx] = seg_tree[2 * st_idx] + seg_tree[2 * st_idx + 1];
    }

    // Sum of leaves in [qs, qe]. st_idx is the current node, covering [start, end].
    // Returns ll to match the node type (the previous int return narrowed it).
    ll rangeSumQuery(ll st_idx, ll start, ll end, ll qs, ll qe) const {
        if (qs > end || qe < start) return 0;                     // Case-1: No Overlap
        if (start >= qs && end <= qe) return seg_tree[st_idx];    // Case-2: Total Overlap
        // Case-3: Partial Overlap
        const ll mid = start + (end - start) / 2;
        const ll left_sum = rangeSumQuery(2 * st_idx, start, mid, qs, qe);
        const ll right_sum = rangeSumQuery(2 * st_idx + 1, mid + 1, end, qs, qe);
        return left_sum + right_sum;
    }

public:
    // Returns the number of good triplets. The return value is 0 when fewer than
    // three elements are given.
    long long goodTriplets(std::vector<int>& nums1, std::vector<int>& nums2) {
        const ll n = static_cast<ll>(nums1.size());
        // A triplet needs three elements. This guard also keeps the nums1[0]
        // access below in bounds (it was UB for empty input).
        if (n < 3) return 0;

        // Step-1: Reset the segment tree and record value -> index for nums2.
        seg_tree.assign(4 * nums1.size() + 1, 0);
        std::unordered_map<int, ll> nums2_val_idx;
        nums2_val_idx.reserve(nums2.size());
        for (ll i = 0; i < n; ++i)
            nums2_val_idx[nums2[i]] = i; // values are distinct, so keys never collide

        // Step-2: Seed the tree with nums1[0] (it can never be a middle element),
        // then treat each nums1[i], 1 <= i <= n-2, as the middle of a triplet.
        // .at() throws instead of silently inserting a missing value.
        updateSegTree(1, 0, n - 1, nums2_val_idx.at(nums1[0]));
        ll count_good_triplets = 0;
        for (ll i = 1; i < n - 1; ++i) {
            const ll nums2_idx = nums2_val_idx.at(nums1[i]);
            // Elements left of i in nums1 that are also left of nums2_idx in nums2.
            // The leaf at nums2_idx itself is still 0, so the inclusive bound is safe.
            const ll common_left_elements = rangeSumQuery(1, 0, n - 1, 0, nums2_idx);
            // The remaining left-of-i elements lie right of nums2_idx in nums2...
            const ll uncommon_left_elements = i - common_left_elements;
            // ...so they are excluded from the elements right of nums2_idx in nums2.
            const ll common_right_elements = (n - nums2_idx - 1) - uncommon_left_elements;
            count_good_triplets += common_left_elements * common_right_elements;
            updateSegTree(1, 0, n - 1, nums2_idx);
        }
        return count_good_triplets;
    }
};
/*
//JAVA
import java.util.HashMap;
import java.util.Map;
class Solution {
private long[] segTree;
private void updateSegTree(int stIdx, int start, int end, int queryIdx) {
if (end < queryIdx || start > queryIdx) return;
if (start == end) {
segTree[stIdx]++;
return;
}
int mid = start + (end - start) / 2;
updateSegTree(2 * stIdx, start, mid, queryIdx);
updateSegTree(2 * stIdx + 1, mid + 1, end, queryIdx);
segTree[stIdx] = segTree[2 * stIdx] + segTree[2 * stIdx + 1];
}
private long rangeSumQuery(int stIdx, int start, int end, int qs, int qe) {
if (qs > end || qe < start) return 0;
if (start >= qs && end <= qe) return segTree[stIdx];
int mid = start + (end - start) / 2;
long leftSum = rangeSumQuery(2 * stIdx, start, mid, qs, qe);
long rightSum = rangeSumQuery(2 * stIdx + 1, mid + 1, end, qs, qe);
return leftSum + rightSum;
}
public long goodTriplets(int[] nums1, int[] nums2) {
int n = nums1.length;
segTree = new long[4 * n + 1];
Map<Integer, Integer> nums2ValIdx = new HashMap<>();
for (int i = 0; i < n; i++) {
nums2ValIdx.put(nums2[i], i);
}
updateSegTree(1, 0, n - 1, nums2ValIdx.get(nums1[0]));
long countGoodTriplets = 0;
for (int i = 1; i < n - 1; i++) {
int nums2Idx = nums2ValIdx.get(nums1[i]);
long commonLeftElements = rangeSumQuery(1, 0, n - 1, 0, nums2Idx);
long uncommonLeftItems = i - commonLeftElements;
long commonRightElements = (n - nums2Idx - 1) - uncommonLeftItems;
countGoodTriplets += commonLeftElements * commonRightElements;
updateSegTree(1, 0, n - 1, nums2Idx);
}
return countGoodTriplets;
}
}
#Python
class Solution:
def goodTriplets(self, nums1: List[int], nums2: List[int]) -> int:
n = len(nums1)
self.seg_tree = [0] * (4 * n + 1)
nums2_val_idx = {val: idx for idx, val in enumerate(nums2)}
def update_seg_tree(st_idx, start, end, query_idx):
if end < query_idx or start > query_idx:
return
if start == end:
self.seg_tree[st_idx] += 1
return
mid = start + (end - start) // 2
update_seg_tree(2 * st_idx, start, mid, query_idx)
update_seg_tree(2 * st_idx + 1, mid + 1, end, query_idx)
self.seg_tree[st_idx] = self.seg_tree[2 * st_idx] + self.seg_tree[2 * st_idx + 1]
def range_sum_query(st_idx, start, end, qs, qe):
if qs > end or qe < start:
return 0
if start >= qs and end <= qe:
return self.seg_tree[st_idx]
mid = start + (end - start) // 2
left_sum = range_sum_query(2 * st_idx, start, mid, qs, qe)
right_sum = range_sum_query(2 * st_idx + 1, mid + 1, end, qs, qe)
return left_sum + right_sum
update_seg_tree(1, 0, n - 1, nums2_val_idx[nums1[0]])
count_good_triplets = 0
for i in range(1, n - 1):
nums2_idx = nums2_val_idx[nums1[i]]
common_left_elements = range_sum_query(1, 0, n - 1, 0, nums2_idx)
uncommon_left_items = i - common_left_elements
common_right_elements = (n - nums2_idx - 1) - uncommon_left_items
count_good_triplets += common_left_elements * common_right_elements
update_seg_tree(1, 0, n - 1, nums2_idx)
return count_good_triplets
*/
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment