Skip to content

Instantly share code, notes, and snippets.

@Sunmish
Created February 2, 2023 08:30
Show Gist options
  • Save Sunmish/fa5039d3d892e82ab52c5c4ee015ae32 to your computer and use it in GitHub Desktop.
Save Sunmish/fa5039d3d892e82ab52c5c4ee015ae32 to your computer and use it in GitHub Desktop.
#! /usr/bin/env python
import os
from argparse import ArgumentParser
import numpy as np
from astropy.coordinates import SkyCoord, EarthLocation, AltAz, FK5
from astropy import units as u
from astropy.io import fits
from astropy.time import Time
from subprocess import Popen
import pandas as pd
from sklearn.cluster import DBSCAN, KMeans
from matplotlib.colors import rgb2hex
from matplotlib.cm import get_cmap
# Observer location used for the RA/Dec -> Alt/Az transforms in match().
# NOTE(review): presumably the Murchison Widefield Array site, going by the
# name and the mwatelescope.org metadata service used below -- confirm.
MWA = EarthLocation.from_geodetic(lat=-26.703319*u.deg,
lon=116.67081*u.deg,
height=377*u.m)
class Pair():
    """A matched pair of observations, one taken from each input list.

    The constructor stores its arguments unchanged.  As filled in by
    ``match()``:

    obs1, obs2 : observation IDs from the first and second list.
    ra, dec    : mean position of the two pointings (from
                 ``mean_coordinates``).
    sep        : separation between the two pointings.
    ant1, ant2 : the antenna strings read alongside each observation ID.
    """

    def __init__(self, obs1, obs2, ra, dec, sep, ant1, ant2):
        """Store the identifiers, position, separation and antenna strings."""
        self.obs1 = obs1
        self.obs2 = obs2
        self.ra = ra
        self.dec = dec
        self.sep = sep
        self.ant1 = ant1
        self.ant2 = ant2

    def __repr__(self):
        # Lists of pairs are printed for diagnostics; show the contents
        # instead of the default "<Pair object at 0x...>".
        return (
            "Pair(obs1={!r}, obs2={!r}, ra={!r}, dec={!r}, sep={!r}, "
            "ant1={!r}, ant2={!r})"
        ).format(self.obs1, self.obs2, self.ra, self.dec, self.sep,
                 self.ant1, self.ant2)
class Group():
    """Accumulator for the observations assigned to one cluster.

    ``obs``, ``ra``, ``dec`` and ``ants`` start as empty lists that callers
    append to; ``merge()`` then finalises them.
    """

    def __init__(self):
        """Create an empty group."""
        self.obs, self.ra, self.dec, self.ants = [], [], [], []

    def merge(self):
        """Finalise the group in place.

        ``obs`` and ``ants`` become numpy arrays, while ``ra`` and ``dec``
        are replaced by the single mean position returned by
        ``mean_coordinates``.
        """
        self.obs = np.asarray(self.obs)
        ra_values = np.array(self.ra)
        dec_values = np.array(self.dec)
        self.ra, self.dec = mean_coordinates(ra_values, dec_values)
        self.ants = np.asarray(self.ants)
def make_region(obslist, color, dash=False):
    """Run the external ``obs2reg`` tool on an observation list file.

    Parameters
    ----------
    obslist : str
        Path of the observation list.  The output path is the same name
        with a trailing ".txt" (if present) swapped for ".reg".
    color :
        Colour accepted by ``matplotlib.colors.rgb2hex``; it is passed to
        ``obs2reg -c`` as a hex string.
    dash : bool
        When true, ``-d`` is added to the command line (presumably
        requesting dashed region outlines -- confirm against obs2reg).
    """
    suffix = ".txt"
    # Only strip the extension; str.replace would also remove ".txt"
    # appearing elsewhere in the path.
    stem = obslist[:-len(suffix)] if obslist.endswith(suffix) else obslist
    outname = stem + ".reg"
    color = rgb2hex(color)
    print(color)
    # An argument list (no shell) keeps paths containing spaces or shell
    # metacharacters intact and avoids command injection.
    cmd = ["obs2reg", obslist, "-k", "-o", outname, "-c", color]
    if dash:
        cmd.append("-d")
    try:
        Popen(cmd).wait()
    except OSError as err:
        # Previously a missing executable only produced a shell error
        # message; keep that best-effort behaviour instead of raising.
        print("Could not run obs2reg: {}".format(err))
def mean_coordinates(ra, dec):
    """Return ``(mean_ra, mean_dec)`` for coordinates given in degrees.

    Declination is averaged arithmetically.  Right ascension is averaged
    as an angle: the mean sine and mean cosine are combined with
    ``arctan2``, so values on either side of the 0/360 wrap average
    correctly.  The resulting RA therefore lies in (-180, 180] rather than
    [0, 360).

    ``ra`` must support ``len()``.
    """
    # TODO: replace with https://docs.astropy.org/en/stable/stats/circ.html
    mean_dec = np.mean(dec)
    count = len(ra)
    ra_rad = np.radians(ra)
    sin_avg = np.sum(np.sin(ra_rad)) / count
    cos_avg = np.sum(np.cos(ra_rad)) / count
    mean_ra = np.degrees(np.arctan2(sin_avg, cos_avg))
    return mean_ra, mean_dec
def make_metafits(obsid):
    """Download the metafits file for ``obsid`` to ``<obsid>.metafits``.

    ``wget`` is run against the MWA metadata web service, with its stdout
    and stderr appended to ``make_metafits.log``.  If the download fails,
    any output file left behind is removed, so that a later
    ``os.path.exists`` check does not mistake it for a valid metafits file.
    """
    url = "http://ws.mwatelescope.org/metadata/fits/?obs_id={0}".format(obsid)
    outname = "{0}.metafits".format(obsid)
    with open("make_metafits.log", "a+") as log:
        try:
            # Argument list instead of a shell string: the "?" in the URL
            # is a shell glob character and must not be expanded.
            status = Popen(["wget", url, "-O", outname],
                           stdout=log, stderr=log).wait()
        except OSError as err:
            # wget missing or not executable; record it like any other
            # download failure rather than raising.
            log.write("could not run wget: {}\n".format(err))
            status = -1
    if status != 0 and os.path.exists(outname):
        # "wget -O" creates the output file even when the fetch fails.
        os.remove(outname)
def match(obs1, obs2, antennas1, antennas2, outname="match.txt", all=False, pair_sep=2., group_sep=4.,
          exclude_obs=[], max_groups=1e3, dec_limits=[-90, 90],
          cmap="Spectral_r",
          regions=False,
          ):
    """Pair each observation in ``obs1`` with its nearest unused one in ``obs2``.

    For every observation a ``<obsid>.metafits`` file is read from the
    current directory (and fetched with ``make_metafits`` when absent); its
    ``RA``, ``DEC`` and ``DATE-OBS`` header values give the pointing and
    time.  Each ``obs1`` entry is compared with the remaining ``obs2``
    entries by RA/Dec separation, and the closest is accepted when that
    separation is below ``pair_sep``.  Alt/Az separations are computed as
    well but are not used in the decision.

    Parameters
    ----------
    obs1, obs2 : sequences of observation IDs.  ``obs1`` entries go to the
        ``*_mwa2.txt`` outputs and ``obs2`` entries to ``*_mwa1.txt``.
    antennas1, antennas2 : per-observation strings, written next to each
        observation ID in the output files.
    outname : output name template; ".txt" is replaced by the suffixes
        "_mwa2.txt", "_mwa1.txt" and "_group{i}_mwa{1,2}.txt".
    all : when true, also cluster the accepted pairs and write one pair of
        files per group.
    pair_sep : maximum accepted RA/Dec separation (compared against
        ``sep.value``; presumably degrees -- confirm).
    group_sep : only used to compute ``eps``, which the active KMeans call
        does not use (it belonged to the commented-out DBSCAN call).
    exclude_obs : observation IDs dropped from both lists before matching.
    max_groups : number of KMeans clusters (converted with ``int``).
    dec_limits : ``[low, high]``; a pair is only kept for grouping when its
        mean declination lies within these limits.  The pair is still
        written to the main output files and returned in ``m1``/``m2``.
    cmap : matplotlib colormap name used to colour the group regions.
    regions : when true, call ``make_region`` on each group file.

    Returns
    -------
    m1, m2 : lists of the matched observation IDs from ``obs1`` and ``obs2``.
    """
    # NOTE(review): the block nesting in this function was reconstructed from
    # the statements themselves -- confirm it against the original file.
    cmap = get_cmap(cmap)
    # do MWA-2 first
    f1 = open(outname.replace(".txt", "_mwa2.txt"), "w+")
    f2 = open(outname.replace(".txt", "_mwa1.txt"), "w+")
    print("Excluding ", exclude_obs)
    # Drop excluded observations, keeping the antenna strings aligned.
    obslist1, obslist2, ants1, ants2 = [], [], [], []
    for i in range(len(obs1)):
        if obs1[i] not in exclude_obs:
            obslist1.append(obs1[i])
            ants1.append(antennas1[i])
    for i in range(len(obs2)):
        if obs2[i] not in exclude_obs:
            obslist2.append(obs2[i])
            ants2.append(antennas2[i])
    # Fetch any metafits file that is not already present locally.
    for obs in obslist1:
        if not os.path.exists("{}.metafits".format(obs)):
            make_metafits(obs)
    for obs in obslist2:
        if not os.path.exists("{}.metafits".format(obs)):
            make_metafits(obs)
    metafits1 = []
    metafits2 = []
    for obs in obslist1:
        metafits1.append(fits.getheader("{}.metafits".format(obs)))
    for obs in obslist2:
        metafits2.append(fits.getheader("{}.metafits".format(obs)))
    # Pointing centres and observation times taken from the headers.
    radec1 = SkyCoord(ra=np.asarray([metafits["RA"] for metafits in metafits1])*u.deg,
                      dec=np.asarray([metafits["DEC"] for metafits in metafits1])*u.deg)
    radec2 = SkyCoord(ra=np.asarray([metafits["RA"] for metafits in metafits2])*u.deg,
                      dec=np.asarray([metafits["DEC"] for metafits in metafits2])*u.deg)
    times1 =np.asarray([Time(metafits["DATE-OBS"], format="isot", scale="utc") \
                        for metafits in metafits1])
    times2 = np.asarray([Time(metafits["DATE-OBS"], format="isot", scale="utc") \
                         for metafits in metafits2])
    altaz1 = radec1.transform_to(AltAz(obstime=times1, location=MWA))
    altaz2 = radec2.transform_to(AltAz(obstime=times2, location=MWA))
    m1, m2 = [], []
    # obs2 is labelled MWA1 and obs1 MWA2, consistent with the file suffixes.
    print("N(MWA1): {}, N(MWA2): {}".format(len(metafits2), len(metafits1)))
    pairs = []
    # Arrays so that matched entries can be removed with np.delete below.
    obslist2 = np.asarray(obslist2)
    ants2 = np.asarray(ants2)
    for i in range(len(metafits1)):
        print(obslist1[i])
        try:
            sep1 = radec1[i].separation(radec2)
            # sep2 (Alt/Az separation) is computed but never read afterwards.
            sep2 = altaz1[i].separation(altaz2)
            # idx1, sep1, _ = radec1[i].match_to_catalog_sky(radec2)
            # idx2, sep2, _ = altaz1[i].match_to_catalog_sky(altaz2)
            # try:
            # Index of the closest remaining obs2 entry.
            lsm = np.argmin([np.abs(sep1.value)])
            print(sep1[lsm].value)
            if sep1[lsm].value < pair_sep and obslist2[lsm] not in m2:
                m1.append(obslist1[i])
                m2.append(obslist2[lsm])
                f1.write("{} {}\n".format(obslist1[i], ants1[i]))
                f2.write("{} {}\n".format(obslist2[lsm], ants2[lsm]))
                ra, dec = mean_coordinates(ra=np.array([radec1[i].ra.value, radec2[lsm].ra.value]),
                                           dec=np.array([radec1[i].dec.value, radec2[lsm].dec.value]))
                p = Pair(obslist1[i], obslist2[lsm], ra, dec, sep1[lsm], ants1[i], ants2[lsm])
                # Only pairs inside the declination limits take part in grouping.
                if (dec_limits[0] <= p.dec <= dec_limits[1]):
                    print(p.dec)
                    pairs.append(p)
                else:
                    print("ignoring due to dec limits")
                # Remove the matched obs2 entry so it cannot be paired again.
                # NOTE(review): assumed to run only after a successful match
                # (nesting reconstructed) -- confirm.
                obslist2 = np.delete(obslist2, lsm, 0)
                ants2 = np.delete(ants2, lsm, 0)
                radec2 = np.delete(radec2, lsm, 0)
                # altaz2 = np.delete(altaz2, lsm, 0)
                # Rebuild the coordinate object from the surviving entries.
                radec2 = SkyCoord(ra=np.array([r.ra.value for r in radec2])*u.deg,
                                  dec=np.array([r.dec.value for r in radec2])*u.deg)
                times2 = np.delete(times2, lsm, 0)
                altaz2 = radec2.transform_to(AltAz(obstime=times2, location=MWA))
        except ValueError:
            # raise
            # NOTE(review): presumably reached when no obs2 entries remain and
            # argmin is given an empty sequence -- confirm.
            print("No match for {}".format(obslist1[i]))
    f1.close()
    f2.close()
    print(pairs)
    # Order the pairs by unwrapped RA (in radians).
    sortie = np.unwrap([np.radians(p.ra) for p in pairs]).argsort()
    pairs = [pairs[i] for i in sortie]
    print(pairs)
    if all:
        groups = {}
        # pcoords is built but not used below.
        pcoords = SkyCoord(ra=np.array([p.ra for p in pairs])*u.deg,
                           dec=np.array([p.dec for p in pairs])*u.deg)
        # One (unwrapped RA, Dec) row in radians per pair, used as features.
        pradians = np.array([np.unwrap([np.radians(p.ra) for p in pairs]),
                             np.array([np.radians(p.dec) for p in pairs])]).T
        print(pradians)
        i = 0
        # eps is only referenced by the commented-out DBSCAN call.
        eps = np.radians(group_sep)
        print(pradians.shape)
        # db = DBSCAN(eps=eps, min_samples=max_groups, algorithm="ball_tree",
        # metric="haversine").fit(pradians)
        db = KMeans(n_clusters=int(max_groups), random_state=0, tol=1e-2).fit(pradians)
        # Collect the members of every cluster label into a Group.
        for cluster in set(db.labels_):
            groups[cluster] = Group()
            for i in range(len(db.labels_)):
                if db.labels_[i] == cluster:
                    # obs/ants are stored interleaved: obs1 entry, then obs2 entry.
                    groups[cluster].obs.append(pairs[i].obs1)
                    groups[cluster].obs.append(pairs[i].obs2)
                    groups[cluster].ra.append(pairs[i].ra)
                    groups[cluster].dec.append(pairs[i].dec)
                    groups[cluster].ants.append(pairs[i].ant1)
                    groups[cluster].ants.append(pairs[i].ant2)
            groups[cluster].merge()
        max_per_group = max([len(groups[i].obs) for i in groups.keys()])
        print("max per group: {}".format(max_per_group))
        for g in groups.keys():
            print("group {}: N={}, N/2={}".format(
                g, len(groups[g].obs), len(groups[g].obs)/2
            ))
        # One colour per group, sampled evenly along the colormap.
        colors = [cmap(i/len(groups)) for i in range(len(groups))]
        print(colors)
        # NOTE(review): indexes groups by 0..len(groups)-1, which assumes the
        # cluster labels present are exactly that range -- confirm.
        for i in range(len(groups)):
            print(groups[i].obs)
            f1 = open(outname.replace(".txt", "_group{}_mwa2.txt".format(i)), "w+")
            f2 = open(outname.replace(".txt", "_group{}_mwa1.txt".format(i)), "w+")
            # Even indices hold obs1 entries, odd indices the matching obs2 entries.
            for j in range(0, len(groups[i].obs), 2):
                f1.write("{} {}\n".format(groups[i].obs[j], groups[i].ants[j]))
                f2.write("{} {}\n".format(groups[i].obs[j+1], groups[i].ants[j+1]))
            f1.close()
            f2.close()
            if regions:
                make_region(outname.replace(".txt", "_group{}_mwa2.txt".format(i)), colors[i],
                            dash=False)
                make_region(outname.replace(".txt", "_group{}_mwa1.txt".format(i)), colors[i],
                            dash=True)
    return m1, m2
def _read_obslist(path):
    """Read an observation list file into parallel ``(obsids, antennas)`` lists.

    Each non-blank line contributes its first whitespace-separated field as
    the observation ID and its second field, or "" when absent, as the
    antenna string.
    """
    obsids, antennas = [], []
    with open(path) as f:
        for line in f:
            bits = line.split()
            if not bits:
                # A blank line (e.g. a trailing newline) has no observation
                # ID; skip it instead of failing with IndexError.
                continue
            obsids.append(bits[0])
            antennas.append(bits[1] if len(bits) > 1 else "")
    return obsids, antennas


def main():
    """Parse the command line, read both observation lists and run ``match()``."""
    ps = ArgumentParser()
    ps.add_argument("obslist1")
    ps.add_argument("obslist2")
    ps.add_argument("-o", "--outname", default="match.txt")
    ps.add_argument("-a", "--all", action="store_true", help="Create files of ALL matches within 2.0 degrees.")
    ps.add_argument("-p", "--pair_sep", default=2.0, type=float)
    ps.add_argument("-g", "--group_sep", default=4.0, type=float)
    ps.add_argument("-e", "--exclude", nargs="*", default=[])
    ps.add_argument("--max_groups", default=1, type=float)
    ps.add_argument("--dec_limits", nargs=2, default=[-90, 90], type=float)
    ps.add_argument("--regions", action="store_true")
    args = ps.parse_args()
    obsidlist1, antennas1 = _read_obslist(args.obslist1)
    obsidlist2, antennas2 = _read_obslist(args.obslist2)
    match(obsidlist1, obsidlist2,
          outname=args.outname,
          all=args.all,
          pair_sep=args.pair_sep,
          group_sep=args.group_sep,
          exclude_obs=args.exclude,
          max_groups=args.max_groups,
          dec_limits=args.dec_limits,
          antennas1=antennas1,
          antennas2=antennas2,
          regions=args.regions)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment