import numpy as np
from math import prod

debug = False


def load_junctions(filename: str):
    """Load a comma-separated junction file into a 2-D integer array.

    Each row of the file is one junction and each column one coordinate.
    ``ndmin=2`` keeps the result 2-D even when the file holds a single row;
    without it ``np.loadtxt`` returns a 1-D array of that row's coordinates,
    which would be misread as several one-coordinate junctions.
    """
    return np.loadtxt(filename, dtype=int, delimiter=",", ndmin=2)


def distances(junctions, sorted=True):
    """Compute the Euclidean distance between every unordered pair of junctions.

    Only the strict upper triangle of the pair matrix (``x < y``) is used, so
    each pair is measured once and no junction is paired with itself.

    :param junctions: array-like with one row per junction; a 1-D array is
        treated as one coordinate per junction.
    :param sorted: when true (default), order the pairs from shortest to
        longest distance; otherwise keep row-major upper-triangle order.
    :returns: ``(X, Y, dists)`` where pair ``i`` is ``(X[i], Y[i])`` and its
        distance is ``dists[i]`` (float).
    """
    points = np.asarray(junctions)
    if points.ndim == 1:
        # One coordinate per junction: give every junction its own row.
        points = points[:, np.newaxis]

    # k=1 skips the diagonal, i.e. the pairs of a junction with itself.
    X, Y = np.triu_indices(len(points), k=1)

    # Measure all pairs in a single vectorized call.
    dists = np.linalg.norm(points[X] - points[Y], axis=1)

    if sorted:
        order = np.argsort(dists)
        X, Y, dists = X[order], Y[order], dists[order]
    return X, Y, dists


def num_of_combinations(length):
    """Return the number of unordered pairs among ``length`` items."""
    return length * (length - 1) // 2


def is_connected(x, connections):
    """Return True if ``x`` appears in any of the given connections."""
    return any(x in connection for connection in connections)


def get_circuits(junctions, connections):
    """Group junction indices into circuits (connected components).

    Every index in ``range(len(junctions))`` starts as its own circuit, and
    each ``(x, y)`` pair in ``connections`` joins the circuits containing
    ``x`` and ``y``. A union-find structure is used, so the cost is close to
    linear in the number of junctions plus connections.

    :returns: a list of non-empty circuits; each circuit is a list of
        distinct members, and every member appears in exactly one circuit.
    """
    # Junction indices are pre-seeded as plain ints. Labels from
    # `connections` that equal one of them (e.g. numpy ints) share its key.
    parent = {i: i for i in range(len(junctions))}

    def find(node):
        parent.setdefault(node, node)
        # Path halving: point each visited node at its grandparent.
        while parent[node] != node:
            parent[node] = parent[parent[node]]
            node = parent[node]
        return node

    for x, y in connections:
        root_x, root_y = find(x), find(y)
        if root_x != root_y:
            parent[root_y] = root_x

    circuits = {}
    for node in list(parent):
        circuits.setdefault(find(node), []).append(node)
    return list(circuits.values())


def product_largest_three_circuits(junctions, num_of_connections):
    """Connect the closest junction pairs and multiply the three largest circuit sizes.

    The ``num_of_connections`` shortest pairs are connected. If there are
    fewer pairs than that, every pair is connected. A pair whose junctions
    already share a circuit still uses up one connection. The result is the
    product of the sizes of the three largest circuits, or of all circuits
    if there are fewer than three.
    """
    X, Y, _ = distances(junctions)
    count = max(num_of_connections, 0)
    # Slicing clamps to the available pairs instead of raising IndexError.
    connections = list(zip(X[:count].tolist(), Y[:count].tolist()))
    circuits = get_circuits(junctions, connections)
    sizes = sorted(len(circuit) for circuit in circuits)
    return prod(sizes[-3:])


if __name__ == "__main__":
    test_junctions = load_junctions("testinput")
    example = product_largest_three_circuits(test_junctions, 10)
    # Explicit check so it still runs under `python -O`.
    if example != 40:
        raise SystemExit(f"example check failed: expected 40, got {example}")
    junctions = load_junctions("input")
    print(product_largest_three_circuits(junctions, 1000))