Skip to content

Instantly share code, notes, and snippets.

@moble
Created March 7, 2025 16:31
Show Gist options
  • Save moble/09e07f433022d65d83b2b0a90e328e22 to your computer and use it in GitHub Desktop.
Save moble/09e07f433022d65d83b2b0a90e328e22 to your computer and use it in GitHub Desktop.
import marimo
# marimo version string recorded in this generated notebook file.
__generated_with = "0.11.17"
# The notebook application object; every cell below registers itself on it via @app.cell.
app = marimo.App(
    width="medium",
    app_title="SXS Catalog",
    # NOTE(review): this stylesheet is not part of this source; confirm it ships next to the notebook.
    css_file="custom.css",
)
@app.cell
def _():
    # Expose the marimo module to the other cells under the name `mo`.
    import marimo as mo
    return (mo,)
@app.cell(hide_code=True)
def _(mo):
    # Page introduction: what is shown here and how to load the same dataframe
    # in a script. Rendered as centered markdown at 80% of the page width.
    _intro_text = r"""
# The SXS Catalog of Simulations
This page presents the simulations published by the SXS collaboration. The core object here is [the dataframe](https://sxs.readthedocs.io/en/main/api/simulations/#simulationsdataframe-class) that can be loaded using [the `sxs` package](https://github.com/sxs-collaboration/sxs/), which can be done in your own scripts as
```python
import sxs
df = sxs.load("dataframe")
```
"""
    _centered = {"width": "80%", "margin": "0 auto"}
    mo.md(_intro_text).style(_centered)
    return
@app.cell(hide_code=True)
def _():
    ### NOTE: This cell is hidden from the user in marimo's "app view".
    ### These commands are mostly here for nicer display; actual users
    ### will probably only need the commands mentioned above.

    # Standard library
    import math
    import warnings

    # Third-party
    # TODO: Pick which one of the plotting libraries (plotly / altair) is better
    import altair as alt
    import numpy as np
    import pandas as pd
    import plotly.express as px
    import pyarrow  # For efficient dataframe manipulation
    import sxs

    sim = sxs.load("simulations", local=True)
    df0 = sim.dataframe  # Unfiltered; the selections below are applied to this

    # `df0` is actually an `sxs.SimulationsDataFrame` (so that we can have
    # attributes like `df.BBH`), which subclasses `pd.DataFrame` but is not
    # exactly one, so marimo's fancy display doesn't work on it directly.
    #
    # Fortunately, we can get the fancy display either by calling
    # `mo.ui.dataframe(df)` or by acting on `df` with some function that
    # returns a regular `pd.DataFrame`.
    return alt, df0, math, np, pd, px, pyarrow, sim, sxs, warnings
@app.cell
def _(mo):
    # A horizontal rule separating the introduction from the interactive part.
    _rule = "---"
    mo.md(_rule)
    return
@app.cell
def _(mo):
    # Lead-in text for the attribute radio buttons in the next cell.
    _text = (
        "The dataframe has several "
        "[useful attributes](https://sxs.readthedocs.io/en/main/api/simulations/#simulationsdataframe-class) "
        "that allow selecting important subsets of the data. "
        "Use the radio buttons below to select those attributes."
    )
    mo.md(_text)
    return
@app.cell(hide_code=True)
def _(mo):
    # One inline radio group per dataframe attribute. Choosing "any" leaves
    # that dimension unfiltered in the filtering cell.
    def _inline_radio(choices, default, caption):
        return mo.ui.radio(choices, value=default, inline=True, label=caption)

    system_type = _inline_radio(
        ["BBH", "IMR (BBH)", "BHNS", "NSNS", "any"], "any", "System type:"
    )
    eccentricity = _inline_radio(
        ["eccentric", "noneccentric", "hyperbolic", "any"], "any", "Eccentricity:"
    )
    precession = _inline_radio(
        ["precessing", "nonprecessing", "any"], "any", "Precession:"
    )
    # Deprecated simulations are hidden unless explicitly requested.
    deprecation = _inline_radio(
        ["deprecated", "undeprecated", "any"], "undeprecated", "Deprecation:"
    )

    df_attributes = mo.vstack([system_type, eccentricity, precession, deprecation])
    df_attributes
    return deprecation, df_attributes, eccentricity, precession, system_type
@app.cell(hide_code=True)
def _(deprecation, df0, eccentricity, precession, system_type):
    # Apply the radio-button selections from above to the dataframe. Each
    # named subset (df.BBH, df.eccentric, ...) is an attribute of the SXS
    # dataframe; "any" leaves that dimension unfiltered.
    df = df0  # Keep the original around as df0
    if system_type.value == "BBH":
        df = df.BBH
    elif system_type.value == "IMR (BBH)":
        df = df.IMR
    elif system_type.value == "BHNS":
        df = df.BHNS
    elif system_type.value == "NSNS":
        df = df.NSNS
    if eccentricity.value == "eccentric":
        df = df.eccentric
    elif eccentricity.value == "noneccentric":
        df = df.noneccentric
    elif eccentricity.value == "hyperbolic":
        df = df.hyperbolic
    if precession.value == "precessing":
        df = df.precessing
    elif precession.value == "nonprecessing":
        df = df.nonprecessing
    if deprecation.value == "deprecated":
        # Boolean-mask on the "deprecated" column, re-wrapping the result in
        # type(df) so it keeps the same dataframe class as the other branches.
        df = type(df)(df[df["deprecated"]])
    elif deprecation.value == "undeprecated":
        df = df.undeprecated

    # And drop any columns that pandas couldn't interpret (these are the 3-vector columns)
    df = df.select_dtypes(exclude=object)

    # Finally, re-order the columns to put these at the front. A preferred
    # column that is absent from this subset (e.g. dropped by select_dtypes
    # above) is skipped instead of raising a KeyError.
    preferred_columns = [
        "number_of_orbits", "reference_mass_ratio", "reference_chi_eff", "reference_chi1_perp",
        "reference_chi2_perp", "reference_eccentricity", "reference_chi1_mag", "reference_chi2_mag"
    ]
    columns = [c for c in preferred_columns if c in df.columns] + [
        c for c in df.columns if c not in preferred_columns
    ]
    # Selecting a list of columns yields a frame that pandas may flag as derived
    # from `df`; copying makes the column assignment below act on an independent
    # frame instead of emitting SettingWithCopyWarning.
    df = df[columns].copy()
    # NOTE(review): an exact duplicate of number_of_orbits whose purpose is not
    # evident from this file (possibly a debugging leftover) — confirm before removing.
    if "number_of_orbits" in df.columns:
        df["number_of_orbits2"] = df["number_of_orbits"]
    return columns, df, preferred_columns
@app.cell(hide_code=True)
def _(mo, simple_filtering):
    # The instructions depend on which table widget is shown below.
    if simple_filtering.value:
        _instructions = "---\nClick the column headings to sort or filter by any value."
    else:
        _instructions = "---\nClick the column headings to sort. Add a transform to filter or otherwise alter the data table."
    mo.md(_instructions)
    return
@app.cell
def _(mo):
    # When checked, the data is shown in a plain sortable/filterable table;
    # otherwise marimo's transform-based dataframe editor is used.
    simple_filtering = mo.ui.checkbox(
        label="Simple filtering",
    )
    simple_filtering
    return (simple_filtering,)
@app.cell
def _(df, mo, simple_filtering):
    # Pick the table widget: a simple table with per-column summaries, or a
    # dataframe editor that lets the user add transforms.
    if not simple_filtering.value:
        table = mo.ui.dataframe(df, page_size=25)
    else:
        # One more than the number of data columns; presumably so the index
        # column also fits — TODO confirm.
        _column_limit = len(df.columns) + 1
        table = mo.ui.table(
            df,
            page_size=25,
            show_column_summaries=True,
            max_columns=_column_limit,
        )
    table
    return (table,)
@app.cell(hide_code=True)
def _(mo, simple_filtering):
    # Describe what the plots below draw; only the simple table offers row
    # checkboxes, so the extra sentence is added only in that mode.
    _sentences = ["We now plot selected columns from the filtered data above."]
    if simple_filtering.value:
        _sentences.append("If you have selected checkboxes in the table, only those will be plotted.")
    plot_message = " ".join(_sentences)
    mo.md(plot_message)
    return (plot_message,)
@app.cell(hide_code=True)
def _(simple_filtering, table):
    # The rows to plot. The two table widgets expose different things as `.value`.
    if simple_filtering.value:
        # mo.ui.table: `.value` is the rows the user checked; with none
        # checked, plot every row of the table.
        table_data = table.value if len(table.value) > 0 else table.data
    else:
        # mo.ui.dataframe: `.value` is the dataframe after the user's
        # transforms. An empty result means the transforms removed every row,
        # so plot nothing rather than silently reverting to untransformed data.
        table_data = table.value
    return (table_data,)
@app.cell(hide_code=True)
def _(df, mo):
    # Dropdowns choosing which column drives each visual channel of the plots.
    _column_names = df.columns.to_list()

    # Both axes must always be bound to a column, so "none" is not offered.
    horizontal_axis = mo.ui.dropdown(
        options=_column_names,
        value="reference_mass_ratio",
        label="Horizontal axis",
        allow_select_none=False,
    )
    vertical_axis = mo.ui.dropdown(
        options=_column_names,
        value="reference_chi_eff",
        label="Vertical axis",
        allow_select_none=False,
    )
    # Size and color start unset (value None), i.e. the channel is unused.
    marker_size = mo.ui.dropdown(options=_column_names, label="Marker size")
    marker_color = mo.ui.dropdown(options=_column_names, label="Marker color")

    selectors = [horizontal_axis, vertical_axis, marker_size, marker_color]

    _axes_row = mo.hstack([horizontal_axis, vertical_axis], justify="space-around")
    _markers_row = mo.hstack([marker_size, marker_color], justify="space-around")
    mo.vstack([_axes_row, _markers_row], justify="space-around")
    return horizontal_axis, marker_color, marker_size, selectors, vertical_axis
@app.cell(hide_code=True)
def _(
    alt,
    horizontal_axis,
    marker_color,
    marker_size,
    mo,
    selectors,
    table_data,
    vertical_axis,
):
    # Interactive Altair scatter plot of `table_data`, with one encoding
    # channel per dropdown above; points selected here feed `chart.value`.

    # Columns bound to some channel, in dropdown order. The same column may be
    # chosen for more than one channel (e.g. vertical axis and color); listing
    # it twice would give `df_restricted` duplicate column names and repeat the
    # tooltip field, so only its first occurrence is kept.
    used_selectors = list(
        dict.fromkeys(s.value for s in selectors if s.value is not None)
    )
    df_restricted = table_data[used_selectors]
    kwargs = dict(
        x=horizontal_axis.value,
        y=vertical_axis.value,
        tooltip=used_selectors,
    )
    # Size and color are optional channels; an unset dropdown has value None.
    if marker_size.value is not None:
        kwargs["size"] = marker_size.value
    if marker_color.value is not None:
        # Possible color schemes are listed on https://vega.github.io/vega/docs/schemes/
        kwargs["color"] = alt.Color(marker_color.value).scale(scheme="viridis")
    chart = mo.ui.altair_chart(
        alt.Chart(df_restricted, height=400)
        .mark_circle(
            stroke="#303030",
            strokeWidth=1,
            opacity=0.8,
        )
        .encode(**kwargs),
        legend_selection=True,
        label="SXS Simulations",
    )
    chart
    return chart, df_restricted, kwargs, used_selectors
@app.cell
def _(chart):
    # Rows used downstream: the points selected in the chart, or all of the
    # chart's data when nothing is selected.
    _selection = chart.value
    chart_data = chart.data if len(_selection) == 0 else _selection
    return (chart_data,)
@app.cell
def _(chart_data, mo):
    # Show a copy-pasteable Python snippet that loads every simulation in
    # `chart_data`; its index holds the IDs passed to sxs.load.
    max_width = 6  # How many SXS IDs to allow on one line
    these_simulations = (
        "this simulation" if len(chart_data) == 1 else f"these {len(chart_data)} simulations"
    )
    _ids = [str(_sim_id) for _sim_id in chart_data.index]
    if len(_ids) < max_width:
        # Few enough IDs to fit on a single line.
        _snippet = 'sims = [sxs.load(sim) for sim in ["' + '", "'.join(_ids) + '"]]'
    else:
        # One line per chunk of at most `max_width` quoted IDs; the final
        # chunk simply holds whatever is left over.
        _rows = [
            '    "' + '", "'.join(_ids[_start:_start + max_width]) + '"'
            for _start in range(0, len(_ids), max_width)
        ]
        _snippet = "sims = [sxs.load(sim) for sim in [\n" + ",\n".join(_rows) + "\n]]"
    (
        # Both layouts share one template, so the code fence is always closed
        # with a full ``` line.
        mo.md(f"You can load {these_simulations} with\n```python\n{_snippet}\n```")
        if _ids
        else None
    )
    return max_width, these_simulations
@app.cell
def _(mo):
    # Caption for the selection table in the next cell.
    _caption = "The following table shows only the data selected in the plot above."
    mo.md(_caption)
    return
@app.cell
def _(chart, mo):
    # Show only the points actually selected in the chart. `chart.value` is
    # the current selection and is empty when nothing is selected. (The
    # `chart_data` cell falls back to all plotted rows in that case, which
    # would make the prompt below unreachable for any non-empty plot.)
    _selected = chart.value
    (
        mo.ui.table(_selected, page_size=25)
        if len(_selected) > 0
        else mo.md("Select a region in the plot above to see details here")
    )
    return
@app.cell
def _(
    horizontal_axis,
    marker_color,
    mo,
    px,
    table_data,
    used_selectors,
    vertical_axis,
):
    # Plotly version of the scatter plot, driven by the same dropdowns as the
    # Altair chart; hovering shows every column bound to a channel.
    kwargs2 = {
        "x": horizontal_axis.value,
        "y": vertical_axis.value,
        "hover_data": used_selectors,
    }
    # Marker size is not mapped here; an earlier attempt replaced NaN sizes
    # with inf before passing them to Plotly.
    if marker_color.value is not None:
        kwargs2["color"] = marker_color.value
    plot = mo.ui.plotly(px.scatter(table_data, **kwargs2))
    plot
    return kwargs2, plot
@app.cell
def _():
    # Empty cell: it defines nothing and displays nothing.
    return
# Running this file directly starts the notebook as an app.
if __name__ == "__main__":
    app.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment