Day 34: Segment Trees Advanced

Day 34: Segment Trees Advanced

Welcome to Day 34 of my 100 Days of DSA challenge! Today, I delved into advanced Segment Tree concepts, focusing on optimizing range updates and queries using lazy propagation. These advanced techniques allow for efficient solutions to problems involving range sums, range updates, and even counting the distinct elements within a range.

Check out my GitHub repository for all the solutions and progress updates at: 100 Days of DSA

Let’s dive into the challenges I tackled today. 🚀


1. Lazy Propagation Mechanism for Range Updates in a Segment Tree

This program implements a segment tree with lazy propagation. It allows efficient range updates and queries on an array.

  1. Propagation (propagate): Before processing any node, it ensures that all the pending updates are applied by checking the lazy array and propagating the updates down to child nodes if needed.

  2. Range Update (update_range): This function adds a value to a specific range [l, r] by updating the segment tree and marking the affected segments in the lazy array for future propagation.

  3. Query Range (query_range): It calculates the sum of the elements in a given range [l, r], ensuring that any pending updates are applied before summing the values.

  4. Initialization (init): Initializes the segment tree and lazy array to 0 for a given number of elements.

Code:

#include <iostream>
using namespace std;

#define MAX 1000

// Segment tree over n elements (all starting at 0) supporting
// "add val to every element of [l, r]" and "sum of [l, r]" in O(log n).
// Range additions are recorded in `lazy` and only pushed down to children
// when a node is visited again (lazy propagation).
// n is clamped to [0, MAX] so the fixed-size arrays are never overrun; callers
// must pass the same n to init(), update() and query().
class SegmentTree {
    private:
        int tree[MAX * 4];     // tree[node] = sum of node's segment, not yet including lazy[node]
        int lazy[MAX * 4];     // lazy[node] = amount still to be added to EACH element of node's segment

        // Clamp a caller-supplied element count to what the arrays can hold.
        static int clamp_size(int n) {
            return n < 0 ? 0 : (n > MAX ? MAX : n);
        }

        // Fold node's pending per-element addition into its sum and hand it down to the children.
        void propagate(int node, int start, int end) {
            if (lazy[node] != 0) {
                // Every one of the (end - start + 1) elements grows by lazy[node],
                // so the segment sum grows by that amount times the segment length.
                tree[node] += lazy[node] * (end - start + 1);
                if (start != end) {                     // Not a leaf node
                    lazy[2 * node + 1] += lazy[node];   // Defer the addition to the left child
                    lazy[2 * node + 2] += lazy[node];   // Defer the addition to the right child
                }
                lazy[node] = 0;     // The addition is now accounted for at this node
            }
        }

        // Add val to every element of [l, r] that lies in node's segment [start, end].
        void update_range(int node, int start, int end, int l, int r, int val) {
            propagate(node, start, end);    // Make tree[node] current before using or recomputing it

            if (start > end || start > r || end < l)    // No overlap
                return;

            if (start >= l && end <= r) {   // Total overlap: update this node, defer the children
                tree[node] += val * (end - start + 1);
                if (start != end) {
                    lazy[2 * node + 1] += val;
                    lazy[2 * node + 2] += val;
                }
                return;
            }

            // Partial overlap: recurse, then rebuild this node from its (now current) children.
            int mid = (start + end) / 2;
            update_range(2 * node + 1, start, mid, l, r, val);
            update_range(2 * node + 2, mid + 1, end, l, r, val);
            tree[node] = tree[2 * node + 1] + tree[2 * node + 2];
        }

        // Sum of the elements of [l, r] that lie in node's segment [start, end].
        int query_range(int node, int start, int end, int l, int r) {
            propagate(node, start, end);    // Apply any pending addition before reading tree[node]

            if (start > end || start > r || end < l)    // No overlap
                return 0;

            if (start >= l && end <= r) {   // Total overlap
                return tree[node];
            }

            // Partial overlap: combine the answers of both children.
            int mid = (start + end) / 2;
            int left_sum = query_range(2 * node + 1, start, mid, l, r);
            int right_sum = query_range(2 * node + 2, mid + 1, end, l, r);
            return left_sum + right_sum;
        }

    public:
        // Reset the tree to n zero-valued elements.
        void init(int n) {
            int cells = clamp_size(n) * 4;
            for (int i = 0; i < cells; i++) {
                tree[i] = 0;
                lazy[i] = 0;
            }
        }

        // Add val to every element in [l, r] of an n-element array (no-op when n <= 0).
        void update(int l, int r, int val, int n) {
            int elements = clamp_size(n);
            if (elements == 0) {
                return;
            }
            update_range(0, 0, elements - 1, l, r, val);
        }

        // Sum of the elements in [l, r] of an n-element array (0 when n <= 0).
        int query(int l, int r, int n) {
            int elements = clamp_size(n);
            if (elements == 0) {
                return 0;
            }
            return query_range(0, 0, elements - 1, l, r);
        }
};

int main() {
    SegmentTree st;
    int n = 10;  
    st.init(n);
    st.update(2, 5, 3, n);  
    st.update(0, 3, 2, n);
    cout << "Sum of elements from index 1 to 4: " << st.query(1, 4, n) << endl;  
    st.update(4, 8, 5, n);  
    cout << "Sum of elements from index 3 to 7: " << st.query(3, 7, n) << endl;  
    return 0;
}

Output:


2. Range Add and Range Sum Query

This program implements a segment tree with lazy propagation to efficiently handle range updates and range sum queries. The tree stores the sum of elements in ranges, and the lazy array helps in postponing updates to specific ranges until they are needed (propagated). The update function applies an increment to all elements within a specified range, and the query function returns the sum of elements in a given range, considering any pending updates.

Code:

#include <iostream>
#include <cstring>
using namespace std;

#define MAX 1000

// Segment tree over n elements (all starting at 0) supporting
// "add value to every element of [l, r]" and "sum of [l, r]" in O(log n).
// Lazy propagation defers range additions until a segment is visited again.
class SegmentTree {
    private:
        int tree[MAX * 4];     // tree[node] = sum of node's segment, not yet including lazy[node]
        int lazy[MAX * 4];     // lazy[node] = per-element addition still pending for node's segment
        int num_elements;      // Elements covered: the constructor's n clamped to [0, MAX]

        // Apply node's pending addition to its sum and defer it to the children.
        void propagate(int node, int start, int end) {
            if (lazy[node] != 0) {
                tree[node] += lazy[node] * (end - start + 1);   // Each element of the segment grows
                if (start != end) {
                    lazy[node * 2 + 1] += lazy[node];
                    lazy[node * 2 + 2] += lazy[node];
                }
                lazy[node] = 0;
            }
        }

        // Add value to every element of [l, r] inside node's segment [start, end].
        void updateRange(int node, int start, int end, int l, int r, int value) {
            propagate(node, start, end);
            if (start > end || start > r || end < l) {      // No overlap
                return;
            }

            if (start >= l && end <= r) {                   // Total overlap: defer to children
                tree[node] += value * (end - start + 1);
                if (start != end) {
                    lazy[node * 2 + 1] += value;
                    lazy[node * 2 + 2] += value;
                }
                return;
            }

            int mid = (start + end) / 2;                    // Partial overlap
            updateRange(node * 2 + 1, start, mid, l, r, value);
            updateRange(node * 2 + 2, mid + 1, end, l, r, value);
            tree[node] = tree[node * 2 + 1] + tree[node * 2 + 2];
        }

        // Sum of the elements of [l, r] inside node's segment [start, end].
        int queryRange(int node, int start, int end, int l, int r) {
            propagate(node, start, end);
            if (start > end || start > r || end < l) {      // No overlap
                return 0;
            }

            if (start >= l && end <= r) {                   // Total overlap
                return tree[node];
            }

            int mid = (start + end) / 2;                    // Partial overlap
            int left_query = queryRange(node * 2 + 1, start, mid, l, r);
            int right_query = queryRange(node * 2 + 2, mid + 1, end, l, r);
            return left_query + right_query;
        }

    public:
        // Create a tree over n zero-valued elements. n is clamped to [0, MAX] so the
        // fixed-size arrays are never overrun; the tree spans exactly those elements
        // instead of all MAX slots.
        SegmentTree(int n) {
            std::memset(tree, 0, sizeof(tree));
            std::memset(lazy, 0, sizeof(lazy));
            num_elements = n < 0 ? 0 : (n > MAX ? MAX : n);
        }

        // Add value to every element in [l, r]; indices outside [0, n - 1] are ignored.
        void update(int l, int r, int value) {
            updateRange(0, 0, num_elements - 1, l, r, value);
        }

        // Sum of the elements in [l, r]; indices outside [0, n - 1] contribute nothing.
        int query(int l, int r) {
            return queryRange(0, 0, num_elements - 1, l, r);
        }
};

int main() {
    int n = 5; 
    SegmentTree st(n);
    st.update(1, 3, 5);
    cout << "Sum in range [2, 4]: " << st.query(2, 4) << endl;
    st.update(0, 2, 3);
    cout << "Sum in range [1, 3]: " << st.query(1, 3) << endl;
    return 0;
}

Output:


3. Maximum Prefix Sum in a Range

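This program uses a segment tree to calculate the maximum prefix sum of a specified range [l, r] of an array, i.e. the largest value of arr[l] + arr[l+1] + ... + arr[k] over all k in [l, r]. It constructs the segment tree with a build function in which each node summarizes its segment of the array. The best prefix of a combined segment has two possible forms. It is either the best prefix of its left half, or the entire left half's sum plus the best prefix of its right half. For this reason, each node must store both its segment's total sum and its maximum prefix sum. The query function then finds the maximum prefix sum in a given range by traversing the segment tree and combining the results of the relevant segments from left to right.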

Code:

#include <iostream>
#include <climits>
using namespace std;

#define MAX 1000    // Maximum size of array

// Segment Tree answering "maximum prefix sum of arr[l..r]", i.e. the largest
// value of arr[l] + arr[l + 1] + ... + arr[k] over all k in [l, r].
// Each node keeps its segment's total and best prefix, so two adjacent
// segments combine in O(1) and a query costs O(log n).
class SegmentTree {
    private:
        // Summary of one contiguous segment of the array.
        struct Summary {
            int sum;            // Total of all elements in the segment
            int best_prefix;    // Maximum sum of a non-empty prefix of the segment
        };

        Summary seg_tree[MAX * 4];  // seg_tree[node] summarizes node's segment (children 2n+1, 2n+2)
        int num_elements = 0;       // Elements the tree was built over; 0 until init()

        // Summary of the concatenation `left` followed by `right`.
        static Summary combine(const Summary& left, const Summary& right) {
            Summary merged;
            merged.sum = left.sum + right.sum;
            // The best prefix either stops inside `left`, or takes all of `left` and continues into `right`.
            merged.best_prefix = std::max(left.best_prefix, left.sum + right.best_prefix);
            return merged;
        }

        // Build the summaries for arr[start..end] rooted at `node`.
        void build(int node, int start, int end, int arr[]) {
            if (start == end) {
                seg_tree[node] = Summary{arr[start], arr[start]};   // A single element is its only prefix
                return;
            }
            int mid = (start + end) / 2;
            build(2 * node + 1, start, mid, arr);       // Left child
            build(2 * node + 2, mid + 1, end, arr);     // Right child
            seg_tree[node] = combine(seg_tree[2 * node + 1], seg_tree[2 * node + 2]);
        }

        // Summary of arr[l..r]; requires start <= l <= r <= end.
        // Narrowing [l, r] to each child keeps the pieces contiguous and in
        // left-to-right order, which `combine` relies on.
        Summary query(int node, int start, int end, int l, int r) const {
            if (l == start && r == end) {
                return seg_tree[node];
            }
            int mid = (start + end) / 2;
            if (r <= mid) {
                return query(2 * node + 1, start, mid, l, r);         // Entirely in the left half
            }
            if (l > mid) {
                return query(2 * node + 2, mid + 1, end, l, r);       // Entirely in the right half
            }
            return combine(query(2 * node + 1, start, mid, l, mid),
                           query(2 * node + 2, mid + 1, end, mid + 1, r));
        }

    public:
        // Build the tree over arr[0..n-1]. n is clamped to [0, MAX] so seg_tree is never
        // overrun, and an empty array no longer recurses forever.
        void init(int arr[], int n) {
            num_elements = n < 0 ? 0 : (n > MAX ? MAX : n);
            if (num_elements > 0) {
                build(0, 0, num_elements - 1, arr);
            }
        }

        // Maximum prefix sum of arr[l..r]. [l, r] is clipped to the built range;
        // returns INT_MIN when nothing remains (l > r, or the tree is empty).
        int get_max_prefix_sum(int l, int r) {
            if (l < 0) {
                l = 0;
            }
            if (r > num_elements - 1) {
                r = num_elements - 1;
            }
            if (l > r) {
                return INT_MIN;
            }
            // Query over the same [0, n-1] range the tree was built on, so the node layout matches.
            return query(0, 0, num_elements - 1, l, r).best_prefix;
        }
};

// Read an array and answer "maximum prefix sum in [l, r]" queries interactively.
int main() {
    int arr[MAX] = {};
    int n = 0;
    int q = 0;
    cout << "Enter the number of elements: ";
    // Reject sizes that would overflow arr (or leave the tree empty) before reading any elements.
    if (!(cin >> n) || n < 1 || n > MAX) {
        cerr << "The number of elements must be between 1 and " << MAX << "." << endl;
        return 1;
    }
    cout << "Enter the elements: ";
    for (int i = 0; i < n; i++) {
        cin >> arr[i];
    }
    SegmentTree segment_tree;
    segment_tree.init(arr, n);
    cout << "Enter the number of queries: ";
    cin >> q;
    for (int i = 0; i < q; i++) {
        int l = 0;
        int r = 0;
        cout << "Enter the range (l r): ";
        if (!(cin >> l >> r)) {
            break;  // Stop on malformed or missing input instead of answering with stale values
        }
        cout << "Maximum Prefix Sum in range [" << l << ", " << r << "] is: " << segment_tree.get_max_prefix_sum(l, r) << endl;
    }
    return 0;
}

Output:


4. Number of Distinct Elements in a Range

This program uses a Segment Tree to efficiently find the number of distinct elements in a given range [l, r] of an array. The SegmentTree class has a build method to construct the tree and a query method to collect the distinct elements in a range. Each node of the segment tree stores the set of distinct values in its segment, obtained by merging the sets of its left and right children. A query takes the union of the sets of the nodes that exactly cover [l, r]. The get_distinct_count function calls query on the segment tree and returns the size of the resulting set, which is the number of distinct elements in the range.

Code:

#include <iostream>
#include <cstring>
#include <set>
using namespace std;

#define MAX 1000

// Segment tree in which every node stores the set of distinct values of its
// segment. The number of distinct values in arr[l..r] is the size of the union
// of the node sets that exactly cover [l, r].
class SegmentTree {
    private:
        std::set<int> segment_tree[4 * MAX];   // 1-based layout: children of node are 2*node and 2*node + 1
        int arr[MAX];                          // Copy of the first num_elements input values
        int num_elements = 0;                  // Elements the tree was built over; 0 until init()

        // Rebuild node's set (and its whole subtree) for arr[start..end].
        void build(int node, int start, int end) {
            segment_tree[node].clear();        // Drop values left over from a previous init()
            if (start == end) {
                segment_tree[node].insert(arr[start]);
                return;
            }
            int mid = (start + end) / 2;
            build(2 * node, start, mid);
            build(2 * node + 1, mid + 1, end);
            segment_tree[node].insert(segment_tree[2 * node].begin(), segment_tree[2 * node].end());
            segment_tree[node].insert(segment_tree[2 * node + 1].begin(), segment_tree[2 * node + 1].end());
        }

        // Add to `out` every value of arr[l..r] that lies in node's segment [start, end].
        // Accumulating into one output set avoids copying whole node sets at every level.
        void query(int node, int start, int end, int l, int r, std::set<int>& out) const {
            if (r < start || end < l) {
                return;     // No overlap
            }
            if (l <= start && end <= r) {       // Total overlap
                out.insert(segment_tree[node].begin(), segment_tree[node].end());
                return;
            }
            int mid = (start + end) / 2;
            query(2 * node, start, mid, l, r, out);
            query(2 * node + 1, mid + 1, end, l, r, out);
        }

    public:
        // Copy input_arr[0..n-1] and (re)build the tree. n is clamped to [0, MAX];
        // only those n values are read from input_arr.
        void init(int n, int input_arr[]) {
            num_elements = n < 0 ? 0 : (n > MAX ? MAX : n);
            if (num_elements == 0) {
                return;
            }
            std::memcpy(arr, input_arr, num_elements * sizeof(arr[0]));
            build(1, 0, num_elements - 1);
        }

        // Number of distinct elements in [l, r]; indices outside [0, n - 1] are ignored.
        int get_distinct_count(int l, int r) {
            if (num_elements == 0) {
                return 0;
            }
            std::set<int> result;
            // Query over the same [0, n-1] range the tree was built on, so the node layout matches.
            query(1, 0, num_elements - 1, l, r, result);
            return static_cast<int>(result.size());
        }
};

// Read an array and answer "number of distinct elements in [l, r]" queries interactively.
int main() {
    int n = 0;
    cout << "Enter the size of the array: ";
    // Reject sizes that would overflow arr (or leave the tree empty) before reading any elements.
    if (!(cin >> n) || n < 1 || n > MAX) {
        cerr << "The size of the array must be between 1 and " << MAX << "." << endl;
        return 1;
    }
    int arr[MAX] = {};
    cout << "Enter the array elements: ";
    for (int i = 0; i < n; i++) {
        cin >> arr[i];
    }
    SegmentTree tree;
    tree.init(n, arr);
    int q = 0;
    cout << "Enter the number of queries: ";
    cin >> q;
    while (q-- > 0) {   // A negative count means "no queries" rather than a runaway loop
        int l = 0;
        int r = 0;
        cout << "Enter the range (l r): ";
        if (!(cin >> l >> r)) {
            break;      // Stop on malformed or missing input
        }
        cout << "Number of distinct elements in range [" << l << ", " << r << "] is: " 
             << tree.get_distinct_count(l, r) << endl;
    }
    return 0;
}

Output:


5. Merge K Sorted Arrays Using a Segment Tree

This program merges k sorted arrays using a Segment Tree. It first flattens the k sorted arrays into one single array and then builds a segment tree on this array. The segment tree is built in such a way that each internal node holds the minimum value of its respective range. Then, by querying the segment tree, the program can efficiently retrieve the merged sorted array. The Segment Tree helps in managing and retrieving minimum values efficiently.

Code:

#include <iostream>
#include <climits>
using namespace std;

#define MAX 1000    // Maximum number of elements in each array (and in the segment tree)
#define INF INT_MAX

// Min segment tree: point updates and range-minimum queries in O(log n).
class SegmentTree {
    private:
        int tree[4 * MAX];  // 1-based layout: tree[node] = minimum of node's segment
        int n;              // Number of elements covered (constructor size clamped to [0, MAX])

        // Build the minima for arr[start..end] rooted at `node`.
        void build(int arr[], int node, int start, int end) {
            if (start == end) {
                tree[node] = arr[start];  // Leaf node holds the element value
                return;
            }
            int mid = (start + end) / 2;
            build(arr, 2 * node, start, mid);           // Left child
            build(arr, 2 * node + 1, mid + 1, end);     // Right child
            tree[node] = std::min(tree[2 * node], tree[2 * node + 1]);
        }

        // Set element idx to value; requires start <= idx <= end.
        void update(int node, int start, int end, int idx, int value) {
            if (start == end) {
                tree[node] = value;
                return;
            }
            int mid = (start + end) / 2;
            if (idx <= mid) {
                update(2 * node, start, mid, idx, value);
            } else {
                update(2 * node + 1, mid + 1, end, idx, value);
            }
            tree[node] = std::min(tree[2 * node], tree[2 * node + 1]);  // Refresh this node's minimum
        }

        // Minimum of the elements of [l, r] inside node's segment [start, end]; INF if they don't overlap.
        int query(int node, int start, int end, int l, int r) {
            if (r < start || end < l) {
                return INF;
            }
            if (l <= start && end <= r) {
                return tree[node];
            }
            int mid = (start + end) / 2;
            int left_query = query(2 * node, start, mid, l, r);
            int right_query = query(2 * node + 1, mid + 1, end, l, r);
            return std::min(left_query, right_query);
        }

    public:
        // Build the tree over arr[0..size-1]. size is clamped to [0, MAX] so `tree` is never
        // overrun, and an empty array no longer recurses forever.
        SegmentTree(int arr[], int size) {
            n = size < 0 ? 0 : (size > MAX ? MAX : size);
            if (n > 0) {
                build(arr, 1, 0, n - 1);
            }
        }

        // Set element idx to value. Out-of-range indices are ignored instead of
        // silently overwriting the first or last leaf.
        void update(int idx, int value) {
            if (idx < 0 || idx >= n) {
                return;
            }
            update(1, 0, n - 1, idx, value);
        }

        // Minimum value in [l, r]; INF when the range holds no elements.
        int query(int l, int r) {
            if (n == 0) {
                return INF;     // The root was never built; don't read it
            }
            return query(1, 0, n - 1, l, r);
        }
};

// Function to merge k sorted arrays using a segment tree
// Merge the first k rows of `arr` (each sorted ascending and holding `cols` values)
// and print all values in ascending order.
// A min segment tree over the current head of every row yields the next smallest
// value. A binary search over prefix minima then finds which row it came from, and
// that row's head is advanced. This costs O(log^2 k) per output value.
// Only min(k, rows) rows are read (capped at MAX), and each row holds at most MAX values.
void merge_k_sorted_arrays(int k, int arr[][MAX], int rows, int cols) {
    int num_arrays = k < rows ? k : rows;
    if (num_arrays > MAX) {
        num_arrays = MAX;
    }
    int length = cols > MAX ? MAX : cols;

    cout << "Merged Sorted Array: ";
    if (num_arrays > 0 && length > 0) {
        int cursor[MAX];    // cursor[i] = index of the first not-yet-printed value in row i
        int heads[MAX];     // heads[i] = arr[i][cursor[i]] (initial values for the tree)
        for (int i = 0; i < num_arrays; i++) {
            cursor[i] = 0;
            heads[i] = arr[i][0];
        }
        SegmentTree seg_tree(heads, num_arrays);

        const int total = num_arrays * length;
        for (int emitted = 0; emitted < total; emitted++) {
            int smallest = seg_tree.query(0, num_arrays - 1);

            // Prefix minima are non-increasing, so binary-search the leftmost row
            // whose head equals the current overall minimum.
            int lo = 0;
            int hi = num_arrays - 1;
            while (lo < hi) {
                int mid = lo + (hi - lo) / 2;
                if (seg_tree.query(0, mid) == smallest) {
                    hi = mid;
                } else {
                    lo = mid + 1;
                }
            }

            cout << smallest << " ";
            cursor[lo]++;
            // An exhausted row gets INF so it is never chosen again while real values remain.
            seg_tree.update(lo, cursor[lo] < length ? arr[lo][cursor[lo]] : INF);
        }
    }
    cout << endl;
}

// Demo: merge three sorted rows of three values each.
int main() {
    // Only the first three columns of each row are used; the rest are zero-filled.
    int arr[3][MAX] = {
        {1, 5, 9},
        {2, 6, 8},
        {3, 7, 10}
    };
    const int rows = 3;
    const int cols = 3;
    const int k = rows;     // merge every row
    merge_k_sorted_arrays(k, arr, rows, cols);
    return 0;
}

Output:


Day 34 was a deep dive into advanced segment tree techniques, with a focus on lazy propagation for efficient range updates. These problems showed how segment trees can handle complex queries such as range sums, maximum prefix sums, and counts of distinct elements in a range. Mastering these techniques is essential for solving real-world problems involving large datasets and complex range operations. 🚀