97 lines
2.5 KiB
Python
97 lines
2.5 KiB
Python
import numpy as np
|
|
from math import prod
|
|
|
|
|
|
# Module-level debug switch; nothing in the visible part of this file reads it.
debug = False
|
|
|
|
def load_junctions(filename: str):
    """
    Load junction file into memory.

    The file is parsed as comma-separated integers, one junction per line.

    Parameters
    ----------
    filename : str
        Path of the file to read; handed straight to ``np.loadtxt``.

    Returns
    -------
    numpy.ndarray
        Integer array with one row per line of the file.  The result is
        always two-dimensional, even when the file holds a single line.
    """
    # ndmin=2 keeps a one-line file as a single row.  Without it loadtxt
    # squeezes the result to 1-D and len() would count coordinates instead
    # of junctions.
    return np.loadtxt(filename, dtype=int, delimiter=",", ndmin=2)
|
|
|
|
def distances(junctions, sorted=True):
    """
    Compute the Euclidean distance between every unordered pair of junctions.

    Only index pairs strictly above the diagonal (x < y) are generated, so
    each pair is measured once and no junction is paired with itself.

    Parameters
    ----------
    junctions : array_like
        One junction per entry along the first axis.
    sorted : bool, optional
        When true (the default) the three returned arrays are ordered from
        the shortest to the longest distance.  Equal distances keep their
        row-major pair order.  NOTE: the name shadows the builtin but is
        kept because callers may pass it by keyword.

    Returns
    -------
    X, Y : numpy.ndarray
        Index arrays; ``(X[i], Y[i])`` is the i-th pair, with ``X[i] < Y[i]``.
    dists : numpy.ndarray
        Float array; ``dists[i]`` is the distance between junctions ``X[i]``
        and ``Y[i]``.  All three arrays are empty for fewer than two
        junctions.
    """
    # Work in float so the result has the default float dtype and integer
    # inputs cannot wrap around during the subtraction.
    points = np.asarray(junctions, dtype=float)

    # k=1 starts one step above the diagonal, which leaves out the pairs of
    # each junction with itself.
    X, Y = np.triu_indices(len(points), k=1)

    # One vectorised pass instead of a Python-level loop over every pair.
    # Squares are summed over every axis except the leading pair axis.
    deltas = points[X] - points[Y]
    coordinate_axes = tuple(range(1, deltas.ndim))
    dists = np.sqrt(np.sum(deltas * deltas, axis=coordinate_axes))

    # Sort the distances, X, and Y from shortest to longest distance.
    if sorted:
        # A stable sort makes the order of tied distances deterministic.
        order = np.argsort(dists, kind="stable")
        X, Y, dists = X[order], Y[order], dists[order]

    return X, Y, dists
|
|
|
|
def num_of_combinations(length):
    """
    Count the unordered pairs that can be formed from `length` items.
    """
    # Every item pairs with each of the others; halving removes the
    # double count of (a, b) and (b, a).
    ordered_pairs = length * (length - 1)
    return ordered_pairs // 2
|
|
|
|
def is_connected(x, connections):
    """
    Report whether `x` is a member of at least one of `connections`.
    """
    return any(x in link for link in connections)
|
|
|
|
def get_circuits(junctions, connections):
    """
    Convert connections into circuits.

    Every junction index starts out as its own circuit.  Each connection
    then merges the circuits of the members it names, leaving disjoint
    sets (a union-find over the junction indices).

    Parameters
    ----------
    junctions : sized
        Only its length is used: indices ``0 .. len(junctions) - 1`` are the
        junctions.
    connections : iterable of iterable
        Each connection lists the (hashable) junction indices it joins.

    Returns
    -------
    list of list
        One list per circuit.  Each member appears exactly once, there are
        no empty lists, and a junction without any connection forms a
        single-member circuit.
    """
    # Maps each member to its parent; a root is its own parent.
    parent = {index: index for index in range(len(junctions))}

    def find(node):
        # A member seen only inside a connection is registered on first use.
        parent.setdefault(node, node)
        root = node
        while parent[root] != root:
            root = parent[root]
        # Re-point the walked chain at the root to keep later lookups short.
        while parent[node] != root:
            next_node = parent[node]
            parent[node] = root
            node = next_node
        return root

    for connection in connections:
        members = list(connection)
        if not members:
            continue
        anchor = find(members[0])
        for other in members[1:]:
            other_root = find(other)
            if other_root != anchor:
                parent[other_root] = anchor

    # Group the members by their root; dict order keeps the result stable.
    circuits = {}
    for node in parent:
        circuits.setdefault(find(node), []).append(node)
    return list(circuits.values())
|
|
|
|
def product_largest_three_circuits(junctions, num_of_connections):
    """
    Multiply the sizes of the three largest circuits.

    The first `num_of_connections` pairs returned by `distances` are used
    as connections, the resulting circuits are built with `get_circuits`,
    and the product of the (up to) three largest circuit sizes is returned.
    """
    first, second, _ = distances(junctions)
    links = [(first[k], second[k]) for k in range(num_of_connections)]

    # A circuit may list a member more than once and may be an empty
    # placeholder, so count distinct members and skip the empty ones.
    sizes = sorted(
        len(set(members))
        for members in get_circuits(junctions, links)
        if members != []
    )
    return prod(sizes[-3:])
|
|
|
|
if __name__ == "__main__":
    # Check the "testinput" file against its expected answer before
    # running the "input" file.
    test_junctions = load_junctions("testinput")
    test_answer = product_largest_three_circuits(test_junctions, 10)
    assert test_answer == 40

    junctions = load_junctions("input")
    answer = product_largest_three_circuits(junctions, 1000)
    print(answer)