Skip to content

Commit

Permalink
Improving the t-sne plot
Browse files Browse the repository at this point in the history
  • Loading branch information
HarrisonWilde committed Apr 13, 2023
1 parent 3554bad commit efc2418
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions src/nhssynth/modules/plotting/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,25 @@ def factorize_all_categoricals(


def tsne(
X_gt: pd.DataFrame,
X_syn: pd.DataFrame,
real: pd.DataFrame,
synth: pd.DataFrame,
) -> None:
tsne_gt = TSNE(n_components=2, random_state=0, learning_rate="auto", init="pca")
proj_gt = pd.DataFrame(tsne_gt.fit_transform(factorize_all_categoricals(X_gt)))
tsne_real = TSNE(n_components=2, init="pca")
proj_real = pd.DataFrame(tsne_real.fit_transform(factorize_all_categoricals(real)))

tsne_syn = TSNE(n_components=2, random_state=0, learning_rate="auto", init="pca")
proj_syn = pd.DataFrame(tsne_syn.fit_transform(factorize_all_categoricals(X_syn)))
tsne_synth = TSNE(n_components=2, init="pca")
proj_synth = pd.DataFrame(tsne_synth.fit_transform(factorize_all_categoricals(synth)))

fig = go.Figure()

fig.add_scatter(x=proj_gt[0], y=proj_gt[1], mode="markers", marker=dict(size=5), opacity=0.75, name="Real data")
fig.add_scatter(x=proj_real[0], y=proj_real[1], mode="markers", marker=dict(size=5), opacity=0.75, name="Real data")
fig.add_scatter(
x=proj_syn[0], y=proj_syn[1], mode="markers", marker=dict(size=5), opacity=0.75, name="Synthetic data"
x=proj_synth[0], y=proj_synth[1], mode="markers", marker=dict(size=5), opacity=0.75, name="Synthetic data"
)

# Set axis labels and legend
fig.update_layout(
title="t-SNE plot",
title="t-SNE Plot",
xaxis_title="t-SNE 1",
yaxis_title="t-SNE 2",
legend=dict(x=0, y=1, bgcolor="rgba(255, 255, 255, 0.5)"),
Expand Down

0 comments on commit efc2418

Please sign in to comment.