[LeetCode] 327. Count of Range Sum

Solution

A similar problem is LeetCode 493, Reverse Pairs. Both can be solved efficiently with a Binary Indexed Tree (BIT, also known as a Fenwick tree).

Steps:

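  1. Compute the prefix sums -> sum[], where sum[0] = 0 and sum[i] = nums[0] + ... + nums[i-1]. The range sum nums[j..i-1] then equals sum[i] - sum[j].
  2. For each prefix sum sum[i], compute its bounds -> lowerBounds[i] = sum[i] - upper and upperBounds[i] = sum[i] - lower. A range sum sum[i] - sum[j] lies in [lower, upper] exactly when sum[j] lies in [lowerBounds[i], upperBounds[i]].
  3. all[] = the sorted unique values of (sum[] + lowerBounds[] + upperBounds[]).
  4. Discretize (coordinate-compress): map each value in all[] to its 1-based rank, so that arbitrarily large long values can index the BIT.
  5. Insert sum[0] into the BIT. Then, for each i from 1 to n:
     - use a BIT range query to count how many earlier prefix sums sum[j] (j < i) lie between lowerBounds[i] and upperBounds[i];
     - add that count to the answer;
     - insert sum[i] into the BIT.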

Code

class Solution {
    int N;
    int[] c;
    long[] sum;
    long[] lowerBound;
    long[] upperBound;
    long[] all;
    private int low_bit(int x) {
        return x & -x;
    }
    
    private void add(int x, int v) {
        while (x <= N) {
            c[x] += v;
            x += low_bit(x);
        }
    }
    
    private int query(int x) {
        int ans = 0;
        while (x > 0) {
            ans += c[x];
            x -= low_bit(x);
        }
        return ans;
    }
    
    public int countRangeSum(int[] nums, int lower, int upper) {
        if (nums.length == 0) return 0;
        prepare(nums, lower, upper);
        //discretization
        HashMap<Long, Integer> map = new HashMap();
        for(int i = 0; i < all.length; i++) map.put(all[i], i + 1);
        int ans = 0;
        add(map.get((long)0), 1);
        for(int i = 1; i < sum.length; i++) {
            ans += query(map.get(upperBound[i])) - query(map.get(lowerBound[i])-1);
            add(map.get(sum[i]), 1);
        }
        return ans;
        
    }
    private void prepare(int[] nums, int lower, int upper) {
        sum = new long[nums.length + 1];
        lowerBound = new long[nums.length + 1];
        upperBound = new long[nums.length + 1];
        //unique elements
        HashSet<Long> allSet = new HashSet();
        allSet.add( (long)0);
        for (int i = 1; i <= nums.length; i++) {
            sum[i] = sum[i-1] + nums[i-1];
            allSet.add(sum[i]);
            allSet.add(sum[i] - lower);
            allSet.add(sum[i] - upper);
            lowerBound[i] = sum[i] - upper;
            upperBound[i] = sum[i] - lower;
        }
        all = new long[allSet.size()];
        this.c = new int[allSet.size() + 1];
        this.N = allSet.size();
        int idx = 0;
        for(long x : allSet) {
            all[idx++] = x;
        }
        Arrays.sort(all);
    }
}

Author: huadonghu

Leave a Reply

Your email address will not be published. Required fields are marked *