# standard Python imports
from pathlib import Path
# scientific package imports
import numpy as np
from numpy.linalg import norm
import astropy.units as u
from galaxy.galaxies import Galaxies
class Approaches(Galaxies):
    """
    A class to work with all 3 galaxies when in close proximity.

    Args:
        snap (int):
            Snap number, equivalent to time elapsed.
            Defaults to 801, the last timepoint.
        datadir (str):
            Directory to search first for the required file. Optional, and a
            default list of locations will be searched. Currently unused,
            because only database input is implemented.
        usesql (bool):
            If True, data will be taken from a PostgreSQL database instead of
            text files. Only usesql=True is currently implemented.
        stride (int):
            Optional. For stride=n, get every nth row in the table.
            Only valid with usesql=True.
        ptype (str):
            can be 'lum' for disk+bulge, 'dm' for halo

    Class attributes:
        data (np.ndarray):
            structured array with fields galname, type, m, x, y, z,
            vx, vy, vz (one row per particle)
        time (astropy Quantity):
            elapsed time for this snap, in Myr
        particle_count (int):
            number of rows in `data`
    """

    def __init__(self, snap=801, datadir=None, usesql=False, stride=1,
                 ptype='lum'):
        """
        Initial setup: store snap and ptype, then load data via read_db().

        Raises:
            NotImplementedError: if usesql is False; text-file input is not
                implemented for this class.
            ValueError: (from read_db) if ptype is invalid or no time is
                stored for this snap.
        """
        self.snap = snap
        self.ptype = ptype
        if not usesql:
            raise NotImplementedError(
                "Approaches currently supports only usesql=True")
        self.read_db(stride)

    def read_db(self, stride):
        """
        Get relevant data from a PostgreSQL database and format it to be
        identical to that read from text files.

        Args:
            stride (int):
                Optional. For stride=n, get every nth row in the table,
                counted in (galname, pnum) order.

        Changes:
            `self.time`, `self.particle_count` and `self.data` are set.

        Raises:
            ValueError: if `self.ptype` is not 'lum' or 'dm', or if no
                (non-NULL) time is stored for `self.snap`.

        Returns: nothing
        """
        from galaxy.db import DB

        # Validate ptype before touching the database.
        type_filters = {'lum': 'type in (2,3)', 'dm': 'type=1'}
        try:
            type_filter = type_filters[self.ptype]
        except KeyError:
            raise ValueError("Valid ptype is 'lum' or 'dm'") from None

        # These values are interpolated into SQL text; coercing to int
        # guarantees nothing but an integer can reach the query.
        snap = int(self.snap)
        stride = int(stride)

        colheads = ','.join(['galname', 'type', 'm', 'x', 'y', 'z',
                             'vx', 'vy', 'vz'])
        where = f"WHERE snap={snap} and {type_filter}"
        order = "ORDER BY galname, pnum"
        if stride > 1:
            # Number rows in a fixed (galname, pnum) order so the subsample is
            # reproducible, keep every stride-th row, and sort by that number
            # (pnum is not available in the outer query).
            sql_d = (f"SELECT {colheads} FROM ("
                     f"SELECT {colheads}, ROW_NUMBER() OVER ({order}) AS rn "
                     f"FROM simdata {where}) AS t "
                     f"WHERE rn % {stride} = 0 ORDER BY rn")
        else:
            sql_d = f"SELECT {colheads} FROM simdata {where} {order}"

        sql_t = ("SELECT time FROM simdata WHERE galname in ('MW', 'M31')"
                 f" and snap={snap} LIMIT 1")

        dtype = [('galname', 'U3'), ('type', 'uint8'), ('m', '<f4'),
                 ('x', '<f4'), ('y', '<f4'), ('z', '<f4'),
                 ('vx', '<f4'), ('vy', '<f4'), ('vz', '<f4')]

        db = DB()
        cur = db.get_cursor()
        try:
            # set the elapsed time
            cur.execute(sql_t)
            row = cur.fetchone()
            if row is None or row[0] is None:
                raise ValueError(
                    f"No time found in simdata for snap {snap} "
                    f"(ptype={self.ptype!r})")
            self.time = row[0] * u.Myr

            # set the bulk of the data
            cur.execute(sql_d)
            rows = cur.fetchall()
        finally:
            cur.close()

        self.data = np.array(rows, dtype=dtype)
        self.particle_count = len(self.data)

    def xyz(self):
        """
        Convenience method to get positions as a np.array of shape (3,N),
        rows ordered x, y, z.
        """
        return np.array([self.data[xi] for xi in ('x', 'y', 'z')])

    def vxyz(self):
        """
        Convenience method to get velocities as a np.array of shape (3,N),
        rows ordered vx, vy, vz.
        """
        return np.array([self.data[vxi] for vxi in ('vx', 'vy', 'vz')])