小能豆

In an n-ary tree, how can I count the greater descendants of each node in O(N)?

py

I am having an issue with my algorithm while trying to reduce its complexity. The specific problem has the following input:

  • The first line is an integer N: the total number of nodes in the n-ary tree (1 <= N <= 10^5).
  • The next N lines give the value Vi of each node i, with 1 <= Vi <= 10^9.
  • The following N-1 lines give the parent of nodes 2 through N, in that order; node 1 is the root.

Output: print N lines; line i contains the number of descendants of node i (its children, their children, and so on) whose value is greater than the value of node i.

Example:

Input
8
60
20
50
10
40
10
50
30
1
1
1
2
2
3
4
Output
0
1
0
1
0
0
0
0
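To make the input format and the definition concrete, here is a small sketch that rebuilds the sample tree and applies the definition directly. It assumes the whole input arrives as one whitespace-separated string, and the helper names `parse` and `descendants` are hypothetical. Node 2 (value 20) has one greater descendant (node 5, value 40), node 4 (value 10) has one (node 8, value 30), and node 3 gets 0 because node 7's value 50 is equal, not greater.

```python
# The sample input from the question, one number per line.
SAMPLE = "8\n60\n20\n50\n10\n40\n10\n50\n30\n1\n1\n1\n2\n2\n3\n4\n"

def parse(text):
    """Return (values, children); values[i - 1] is node i's value, children[p] lists p's children."""
    nums = list(map(int, text.split()))
    n = nums[0]
    values = nums[1:n + 1]
    parents = nums[n + 1:2 * n]  # parent of node 2, 3, ..., n
    children = {i: [] for i in range(1, n + 1)}
    for node, parent in enumerate(parents, start=2):
        children[parent].append(node)
    return values, children

def descendants(children, node):
    """All nodes strictly below `node`, collected with an explicit stack."""
    out, stack = [], list(children[node])
    while stack:
        cur = stack.pop()
        out.append(cur)
        stack.extend(children[cur])
    return out

values, children = parse(SAMPLE)
# Definition: compare every descendant with the node itself, not with its direct parent
answers = [sum(values[d - 1] > values[i - 1] for d in descendants(children, i))
           for i in range(1, len(values) + 1)]
print(children)  # {1: [2, 3, 4], 2: [5, 6], 3: [7], 4: [8], 5: [], 6: [], 7: [], 8: []}
print(answers)   # [0, 1, 0, 1, 0, 0, 0, 0]
```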

This is my code. In the worst case it walks the entire subtree of every node, which makes it O(N^2). I am hoping for a faster solution.

import sys
from collections import deque

def calculate(N, values, relations):
    tree = {i: [] for i in range(1, N + 1)}
    for i, parent in enumerate(relations, start=2):
        tree[parent].append(i)

    result = [0] * N

    for i in range(1, N + 1):
        stack = deque([(i, values[i - 1])]) 
        while stack:
            node, parent_value = stack.pop()
            for child in tree[node]:
                if values[child - 1] > parent_value:
                    result[i - 1] += 1
                stack.append((child, parent_value))

    return result

N = int(sys.stdin.readline())
values = [int(sys.stdin.readline()) for _ in range(N)]
relations = [int(sys.stdin.readline()) for _ in range(N - 1)]

for result in calculate(N, values, relations):
    print(result)
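The quadratic worst case is easiest to see on a path-shaped tree, where node i + 1 is the only child of node i. The loop for node i then walks all N - i nodes below it, for N(N-1)/2 inner iterations in total, while a star-shaped tree needs only N - 1. The sketch below counts those iterations by replaying the same traversal with a plain list as the stack; `count_visits`, `chain` and `star` are illustrative helper names, not part of the original code.

```python
def count_visits(parents):
    """Count inner-loop iterations of the nested traversal from the question.

    parents[k] is the parent of node k + 2, as in the input format.
    """
    n = len(parents) + 1
    tree = {i: [] for i in range(1, n + 1)}
    for node, parent in enumerate(parents, start=2):
        tree[parent].append(node)
    visits = 0
    for start in range(1, n + 1):
        stack = [start]
        while stack:
            node = stack.pop()
            for child in tree[node]:
                visits += 1  # one comparison against values[start - 1]
                stack.append(child)
    return visits

def chain(n):
    # Path-shaped tree: node i + 1 is the only child of node i
    return list(range(1, n))

def star(n):
    # Star-shaped tree: every other node is a child of the root
    return [1] * (n - 1)

for n in (10, 100, 1000):
    print(n, count_visits(chain(n)), count_visits(star(n)))
# 10 45 9
# 100 4950 99
# 1000 499500 999
```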

2023-12-12

1 answer

小能豆

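You can avoid re-walking every subtree by doing a single depth-first search (DFS), but adding up the children's counts is not enough on its own: each descendant has to be compared with the node itself, not just with its direct parent. In the sample, node 5 (40) is greater than its parent, node 2 (20), but not greater than the root (60), so it must not be counted for node 1.

One way to do this in a single pass is to keep a Fenwick (binary indexed) tree over the compressed values of the nodes the DFS has already finished. When the DFS enters a node, record how many finished values are greater than the node's value; when it leaves the node, count again. The nodes finished in between are exactly the node's descendants, so the difference is the answer for that node. The version below uses an explicit stack, because a recursive DFS on a 10^5-node chain would exceed Python's default recursion limit of 1000:

import sys
from bisect import bisect_left

def calculate(N, values, relations):
    # Children lists; relations[k] is the parent of node k + 2, node 1 is the root
    tree = [[] for _ in range(N + 1)]
    for i, parent in enumerate(relations, start=2):
        tree[parent].append(i)

    # Compress values to ranks 1..M so they can index a Fenwick tree
    distinct = sorted(set(values))
    rank = [bisect_left(distinct, v) + 1 for v in values]
    M = len(distinct)
    fenwick = [0] * (M + 1)

    def add(i):
        # Record one finished node with rank i
        while i <= M:
            fenwick[i] += 1
            i += i & -i

    def count_le(i):
        # Number of finished nodes with rank <= i
        total = 0
        while i > 0:
            total += fenwick[i]
            i -= i & -i
        return total

    result = [0] * N
    before = [0] * (N + 1)  # greater finished values seen when each node is entered
    finished = 0
    stack = [(1, False)]
    while stack:
        node, leaving = stack.pop()
        r = rank[node - 1]
        greater = finished - count_le(r)  # finished values > values[node - 1]
        if not leaving:
            before[node] = greater
            stack.append((node, True))
            stack.extend((child, False) for child in tree[node])
        else:
            # Only descendants finish between entering and leaving this node
            result[node - 1] = greater - before[node]
            add(r)
            finished += 1
    return result

N = int(sys.stdin.readline())
values = [int(sys.stdin.readline()) for _ in range(N)]
relations = [int(sys.stdin.readline()) for _ in range(N - 1)]

for result in calculate(N, values, relations):
    print(result)

Every node is entered and left exactly once, and each of those steps does O(log N) Fenwick-tree work, so together with the sort used for compression this runs in O(N log N) time and O(N) memory. That is the realistic target for this problem: a strict O(N) bound is not achievable with comparison-based methods in general, because on a path-shaped tree the answers determine the sorted order of the values.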
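Why not O(N)? A minimal illustration, assuming distinct values: on a path-shaped tree (node i + 1 is the only child of node i), the answer for node i is the number of later values greater than values[i]. Those answers alone are enough to rebuild the order of all the values, and the decoding step below never compares two values (the helper names are hypothetical). Any comparison-based algorithm that produces the answers therefore has to distinguish all N! orderings, which takes at least log2(N!) = Θ(N log N) comparisons. The decoder itself is quadratic and only serves as a demonstration.

```python
def greater_later_counts(values):
    """Brute-force answers on the path 1 -> 2 -> ... -> N: node i's descendants are the later positions."""
    n = len(values)
    return [sum(values[j] > values[i] for j in range(i + 1, n)) for i in range(n)]

def order_from_counts(counts):
    """Recover the value order from the answers alone (values assumed distinct).

    Walking from the last node back to the first, `order` holds the indices seen so far
    sorted by value, largest first; counts[i] is exactly how many of them beat values[i],
    so index i belongs at position counts[i]. No values are compared here.
    """
    order = []
    for i in range(len(counts) - 1, -1, -1):
        order.insert(counts[i], i)
    return order

values = [40, 10, 60, 30, 50, 20]
counts = greater_later_counts(values)
print(counts)                                              # [2, 4, 0, 1, 0, 0]
print(order_from_counts(counts))                           # [2, 4, 0, 3, 5, 1]
print(sorted(range(len(values)), key=lambda i: -values[i]))  # [2, 4, 0, 3, 5, 1]
```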

2023-12-12