from itertools import combinations

import networkx as nx
import numpy as np
from sklearn.ensemble import RandomForestClassifier

from db import get_columns
from graph import build_graph

# Module-level state, loaded once at import time.
# `columns` is indexed by table name (columns[t] is handed to jaccard, which
# turns it into a set) -- presumably a mapping of table -> iterable of column
# values/names; TODO confirm against db.get_columns.
columns = get_columns()
# Graph over the same table names used as keys of `columns`: train_model uses
# its edges as positive labels and best_path searches it for join paths.
G = build_graph()

def jaccard(a, b):
    """Return the Jaccard similarity |A & B| / |A | B| of two iterables.

    Both inputs are converted to sets, so duplicates and order are ignored.

    When both inputs are empty the union is empty. In that case 0.0 is
    returned instead of raising ZeroDivisionError, so two empty columns are
    never scored as similar.
    """
    set_a, set_b = set(a), set(b)
    union = set_a | set_b
    if not union:
        return 0.0
    return len(set_a & set_b) / len(union)

def train_model(random_state=None):
    """Fit a classifier that predicts whether two tables are joinable.

    Every unordered pair of tables in ``columns`` becomes one training
    example. Its single feature is the Jaccard similarity of the two tables'
    ``columns`` entries. Its label is 1 if ``G`` has an edge between them,
    otherwise 0.

    Args:
        random_state: Forwarded to ``RandomForestClassifier``. Pass an int
            for a reproducible model. ``None`` (the default) keeps the
            unseeded behavior.

    Returns:
        The fitted ``RandomForestClassifier``.

    Raises:
        ValueError: if ``columns`` has fewer than two tables, so there is
            nothing to train on.
    """
    pairs = list(combinations(columns.keys(), 2))
    if not pairs:
        raise ValueError("need at least two tables in `columns` to train a model")
    X = [[jaccard(columns[t1], columns[t2])] for t1, t2 in pairs]
    # combinations() yields each pair in one order only, and the feature is
    # symmetric. Check both directions so a directed G cannot label a
    # connected pair as 0 just because of iteration order. For an undirected
    # graph the second check is redundant.
    y = [int(G.has_edge(t1, t2) or G.has_edge(t2, t1)) for t1, t2 in pairs]
    clf = RandomForestClassifier(random_state=random_state)
    clf.fit(X, y)
    return clf

# Train the joinability classifier once, at import time. predict_joinability
# reads this module-level `clf`, so the assignment must come after
# train_model's definition and before any prediction.
clf = train_model()

def predict_joinability(t1, t2):
    """Return the model's probability that tables ``t1`` and ``t2`` join.

    The feature is the Jaccard similarity of ``columns[t1]`` and
    ``columns[t2]``. This is the same feature ``train_model`` fits on.

    Returns:
        Probability of the positive class (label 1) as a float. If the
        classifier never saw a positive example in training, it has no
        column for class 1 and 0.0 is returned.
    """
    sim = jaccard(columns[t1], columns[t2])
    classes = list(clf.classes_)
    # predict_proba's columns follow clf.classes_. If training saw only one
    # label there is a single column, so a hard-coded index 1 would raise
    # IndexError. Look the positive class up instead.
    if 1 not in classes:
        return 0.0
    proba = clf.predict_proba([[sim]])[0]
    return float(proba[classes.index(1)])

def best_path(source, target, cutoff=4):
    """Return the simple path from ``source`` to ``target`` with the best score.

    A path's score is the mean of ``predict_joinability`` over its
    consecutive node pairs. Only paths found by
    ``nx.all_simple_paths(G, ..., cutoff=cutoff)`` are considered.

    Args:
        source: Start node in ``G``.
        target: End node in ``G``.
        cutoff: Search depth limit passed to ``nx.all_simple_paths``
            (default 4).

    Returns:
        The best path as a list of nodes, or ``[]`` if no path exists within
        ``cutoff``. Ties go to the first path enumerated.
    """
    # Paths share edges, so memoize per-edge predictions for this call.
    edge_scores = {}

    def _edge_score(u, v):
        # Return the cached predict_joinability(u, v), computing it once.
        if (u, v) not in edge_scores:
            edge_scores[(u, v)] = predict_joinability(u, v)
        return edge_scores[(u, v)]

    def _path_score(path):
        # Mean per-hop joinability along the path.
        return np.mean([_edge_score(u, v) for u, v in zip(path, path[1:])])

    paths = nx.all_simple_paths(G, source=source, target=target, cutoff=cutoff)
    return max(paths, key=_path_score, default=[])

def best_paths(source, targets):
    """Return the best path from ``source`` to each reachable target.

    Targets for which ``best_path`` returns an empty list are dropped. The
    result can therefore be shorter than ``targets`` and is not index-aligned
    with it.
    """
    # Call best_path only once per target: it enumerates every simple path,
    # so evaluating it a second time just for the filter doubles the cost.
    candidates = (best_path(source, t) for t in targets)
    return [path for path in candidates if path]