Skip to content

Commit

Permalink
- Improvement: Improved oblate star model visualisation.
Browse files Browse the repository at this point in the history
hpparvi committed Feb 23, 2021
1 parent 4d30a4b commit 8c08d91
Showing 9 changed files with 425 additions and 81 deletions.
15 changes: 15 additions & 0 deletions .idea/inspectionProfiles/Project_Default.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 12 additions & 19 deletions notebooks/osmodel_example_1.ipynb

Large diffs are not rendered by default.

142 changes: 142 additions & 0 deletions notebooks/osmodel_visualization_example.ipynb

Large diffs are not rendered by default.

66 changes: 66 additions & 0 deletions pytransit/models/numba/osmodel.py
Original file line number Diff line number Diff line change
@@ -154,6 +154,72 @@ def luminosity_v(xs, ys, mstar, rstar, ostar, tpole, gpole, f, sphi, cphi, beta,
l[i] = planck(wavelength, t)*(1. - ldc[0]*(1. - mu) - ldc[1]*(1. - mu)**2)
return l

@njit
def luminosity_v2(ps, normals, istar, mstar, rstar, ostar, tpole, gpole, beta, ldc, wavelength):
npt = ps.shape[0]
l = zeros(npt)
dc = zeros(3)

vx = 0.0
vy = -cos(istar)
vz = -sin(istar)

for i in range(npt):
px, py, pz = ps[i] * rstar # Position vector components
nx, ny, nz = normals[i] # Normal vector components

mu = vy*ny + vz*nz

lp2 = (px**2 + py**2 + pz**2) # Squared distance from center
lc = sqrt(px**2 + pz**2) # Centrifugal vector length
cx, cz = px/lc, pz/lc # Normalized centrifugal vector

gg = -G * mstar / lp2 # Newtionian surface gravity component
gc = ostar * ostar * lc # Centrifugal surface gravity component

gx = gg*nx + gc*cx # Surface gravity x component
gy = gg*ny # Surface gravity y component
gz = gg*nz + gc*cz # Surface gravity z component

g = sqrt((gx**2 + gy**2 + gz**2)) # Surface gravity
t = tpole*g**beta / gpole**beta # Temperature [K]
l[i] = planck(wavelength, t) # Thermal radiation
l[i] *= (1.-ldc[0]*(1.-mu) - ldc[1]*(1.-mu)**2) # Quadratic limb darkening


return l


@njit
def luminosity_s2(p, normal, istar, mstar, rstar, ostar, tpole, gpole, beta, ldc, wavelength):

vx = 0.0
vy = -cos(istar)
vz = -sin(istar)

px, py, pz = p * rstar # Position vector components
nx, ny, nz = normal # Normal vector components

mu = vy*ny + vz*nz

lp2 = (px**2 + py**2 + pz**2) # Squared distance from center
lc = sqrt(px**2 + pz**2) # Centrifugal vector length
cx, cz = px/lc, pz/lc # Normalized centrifugal vector

gg = -G * mstar / lp2 # Newtionian surface gravity component
gc = ostar * ostar * lc # Centrifugal surface gravity component

gx = gg*nx + gc*cx # Surface gravity x component
gy = gg*ny # Surface gravity y component
gz = gg*nz + gc*cz # Surface gravity z component

g = sqrt((gx**2 + gy**2 + gz**2)) # Surface gravity
t = tpole*g**beta / gpole**beta # Temperature [K]
l = planck(wavelength, t)
l *= (1.-ldc[0]*(1.-mu) - ldc[1]*(1.-mu)**2) # Quadratic limb darkening

return l


def create_star_xy(res: int = 64):
st = linspace(-1., 1., res)
129 changes: 77 additions & 52 deletions pytransit/models/osmodel.py
Original file line number Diff line number Diff line change
@@ -16,12 +16,18 @@
from typing import Union

from astropy.constants import R_sun, M_sun
from matplotlib.patches import Circle
from matplotlib.pyplot import subplots, setp
from numpy import linspace, meshgrid, sin, cos, array, ndarray, asarray, squeeze
from numpy import linspace, sin, cos, array, ndarray, asarray, squeeze, cross, newaxis, pi, where, nan, full, degrees
from numpy.linalg import norm
from scipy.spatial.transform.rotation import Rotation

from .transitmodel import TransitModel
from .numba.osmodel import create_star_xy, create_planet_xy, map_osm, xy_taylor_vt, luminosity_v, oblate_model_s
from ..orbits import i_from_ba
from .numba.osmodel import create_star_xy, create_planet_xy, map_osm, xy_taylor_vt, luminosity_v, oblate_model_s, \
luminosity_v2
from ..orbits import as_from_rhop, i_from_baew
from ..orbits.taylor_z import vajs_from_paiew, find_contact_point
from ..utils.octasphere import octasphere


class OblateStarModel(TransitModel):
@@ -54,55 +60,74 @@ def __init__(self, rstar: float = 1.0, wavelength: float = 510, sres: int = 80,
self._ts, self._xs, self._ys = create_star_xy(sres)
self._xp, self._yp = create_planet_xy(pres)

def visualize(self, k, b, alpha, rho, rperiod, tpole, phi, beta, ldc, ires: int = 256):
"""Visualize the model for a set of parameters.
Parameters
----------
k
b
alpha
rho
rperiod
tpole
phi
beta
ldc
ires
Returns
-------
"""
a = 4.5
mstar, ostar, gpole, f, feff = map_osm(self.rstar, rho, rperiod, tpole, phi)
i = i_from_ba(b, a)
times = linspace(-1.1, 1.1)
ox, oy = xy_taylor_vt(times, alpha, -b, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)

x = linspace(-1.1, 1.1, ires)
y = linspace(-1.1, 1.1, ires)
x, y = meshgrid(x, y)
sphi, cphi = sin(phi), cos(phi)

l = luminosity_v(x.ravel()*self.rstar, y.ravel()*self.rstar, mstar, self.rstar, ostar, tpole, gpole,
f, sphi, cphi, beta, ldc, self.wavelength)

fig, axs = subplots(1, 2, figsize=(13, 4))
axs[0].imshow(l.reshape(x.shape), extent=(-1.1, 1.1, -1.1, 1.1), origin='lower')
axs[0].plot(ox, oy, 'w', lw=5, alpha=0.25)
axs[0].plot(ox, oy, 'k', lw=2)

setp(axs[0], ylabel='y [R$_\star$]', xlabel='x [R$_\star$]')

times = linspace(-0.35, 0.35, 500)
flux = oblate_model_s(times, array([k]), 0.0, 4.0, a, alpha, i, 0.0, 0.0, ldc, mstar, self.rstar, ostar, tpole, gpole,
f, feff, sphi, cphi, beta, self.wavelength, self.tres, self._ts, self._xs, self._ys, self._xp, self._yp,
self.lcids, self.pbids, self.nsamples, self.exptimes, self.npb)

axs[1].plot(times, flux, 'k')
setp(axs[1], ylabel='Normalized flux', xlabel='Time - T$_0$')
fig.tight_layout()
def visualize(self, k, p, rho, b, e, w, alpha, rperiod, tpole, istar, beta, ldc, figsize=(5, 5), ax=None,
ntheta=18):
if ax is None:
fig, ax = subplots(figsize=figsize)
ax.set_aspect(1.)
else:
fig, ax = None, ax

a = as_from_rhop(rho, p)
inc = i_from_baew(b, a, e, w)
mstar, ostar, gpole, f, _ = map_osm(rstar=self.rstar, rho=rho, rperiod=rperiod, tpole=tpole, phi=0.0)

# Plot the star
# -------------
vertices_original, faces = octasphere(4)
vertices = vertices_original.copy()
vertices[:, 1] *= (1.0 - f)

triangles = vertices[faces]
centers = triangles.mean(1)
normals = cross(triangles[:, 1] - triangles[:, 0], triangles[:, 2] - triangles[:, 0])
nlength = norm(normals, axis=1)
normals /= nlength[:, newaxis]

rotation = Rotation.from_rotvec((0.5 * pi - istar) * array([1, 0, 0]))
rn = rotation.apply(normals)
rc = rotation.apply(centers)

mask = rn[:, 2] < 0.0
l = luminosity_v2(centers[mask], normals[mask], istar, mstar, self.rstar, ostar, tpole, gpole, beta,
ldc, self.wavelength)
ax.tripcolor(rc[mask, 0], rc[mask, 1], l, shading='gouraud')

nphi = 180
theta = linspace(0 + 0.1, pi - 0.1, ntheta)
phi = linspace(0, 2 * pi, nphi)
for i in range(theta.size):
y = (1.0 - f) * cos(theta[i])
x = cos(phi) * sin(theta[i])
z = sin(phi) * sin(theta[i])
v = rotation.apply(array([x, full(nphi, y), z]).T)
m = v[:, 2] < 0.0
ax.plot(where(m, v[:, 0], nan), v[:, 1], 'k--', lw=1.5, alpha=0.25)

# Plot the orbit
# --------------
y0, vx, vy, ax_, ay, jx, jy, sx, sy = vajs_from_paiew(p, a, inc, e, w)
c1 = find_contact_point(k, 1, y0, vx, vy, ax_, ay, jx, jy, sx, sy)
c4 = find_contact_point(k, 4, y0, vx, vy, ax_, ay, jx, jy, sx, sy)
time = linspace(2 * c1, 2 * c4, 100)

ox, oy = xy_taylor_vt(time, alpha, y0, vx, vy, ax_, ay, jx, jy, sx, sy)
ax.plot(ox, oy, 'k')

pxy = xy_taylor_vt(array([0.0]), alpha, y0, vx, vy, ax_, ay, jx, jy, sx, sy)
ax.add_artist(Circle(pxy, k, zorder=10, fc='k'))

# Plot the info
# -------------
ax.text(0.025, 0.95, f"i$_\star$ = {degrees(istar):.1f}$^\circ$", transform=ax.transAxes)
ax.text(0.025, 0.90, f"i$_\mathrm{{p}}$ = {degrees(inc):.1f}$^\circ$", transform=ax.transAxes)
ax.text(1 - 0.025, 0.95, fr"$\alpha$ = {degrees(alpha):.1f}$^\circ$", transform=ax.transAxes, ha='right')
ax.text(0.025, 0.05, f"f = {f:.1f}", transform=ax.transAxes)

setp(ax, xlim=(-1.1, 1.1), ylim=(-1.1, 1.1), xticks=[], yticks=[])
if fig is not None:
fig.tight_layout()
return ax

def evaluate_ps(self, k: Union[float, ndarray], rho: float, rperiod: float, tpole: float, phi: float,
beta: float, ldc: ndarray, t0: float, p: float, a: float, i: float, l: float = 0.0,
97 changes: 97 additions & 0 deletions pytransit/utils/octasphere.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# PyTransit: fast and easy exoplanet transit modelling in Python.
# Copyright (C) 2010-2021 Hannu Parviainen
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

# This script can generate spheres, rounded cubes, and capsules.
# For more information, see https://prideout.net/blog/octasphere/
# Copyright (c) 2019 Philip Rideout
# Distributed under the MIT License, see bottom of file.

from math import sin, cos, acos, pi
from numpy import empty, array, vstack, cross, dot
from pyrr import quaternion


def octasphere(ndivisions: int):
"""Creates a unit sphere using octagon subdivision.
Creates a unit sphere using octagon subdivision. Modified slightly from the original code
by Philip Rideout (https://prideout.net/blog/octasphere).
"""

n = 2**ndivisions + 1
num_verts = n * (n + 1) // 2
verts = empty((num_verts, 3))
j = 0
for i in range(n):
theta = pi * 0.5 * i / (n - 1)
point_a = [0, sin(theta), cos(theta)]
point_b = [cos(theta), sin(theta), 0]
num_segments = n - 1 - i
j = compute_geodesic(verts, j, point_a, point_b, num_segments)
assert len(verts) == num_verts

num_faces = (n - 2) * (n - 1) + n - 1
faces = empty((num_faces, 3), dtype='int')
f, j0 = 0, 0
for col_index in range(n-1):
col_height = n - 1 - col_index
j1 = j0 + 1
j2 = j0 + col_height + 1
j3 = j0 + col_height + 2
for row in range(col_height - 1):
faces[f + 0] = [j0 + row, j1 + row, j2 + row]
faces[f + 1] = [j2 + row, j1 + row, j3 + row]
f = f + 2
row = col_height - 1
faces[f] = [j0 + row, j1 + row, j2 + row]
f = f + 1
j0 = j2

euler_angles = array([
[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0],
[1, 0, 0], [1, 0, 1], [1, 0, 2], [1, 0, 3],
]) * pi * 0.5
quats = (quaternion.create_from_eulers(e) for e in euler_angles)

offset, combined_verts, combined_faces = 0, [], []
for quat in quats:
rotated_verts = [quaternion.apply_to_vector(quat, v) for v in verts]
rotated_faces = faces + offset
combined_verts.append(rotated_verts)
combined_faces.append(rotated_faces)
offset = offset + len(verts)

return vstack(combined_verts), vstack(combined_faces)


def compute_geodesic(dst, index, point_a, point_b, num_segments):
"""Given two points on a unit sphere, returns a sequence of surface
points that lie between them along a geodesic curve."""

angle_between_endpoints = acos(dot(point_a, point_b))
rotation_axis = cross(point_a, point_b)
dst[index] = point_a
index = index + 1
if num_segments == 0:
return index
dtheta = angle_between_endpoints / num_segments
for point_index in range(1, num_segments):
theta = point_index * dtheta
q = quaternion.create_from_axis_rotation(rotation_axis, theta)
dst[index] = quaternion.apply_to_vector(q, point_a)
index = index + 1
dst[index] = point_b
return index + 1
2 changes: 1 addition & 1 deletion pytransit/version.py
Original file line number Diff line number Diff line change
@@ -16,4 +16,4 @@

from semantic_version import Version

__version__ = Version('2.5.3')
__version__ = Version('2.5.4')
22 changes: 14 additions & 8 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
numpy>=1.19.0
scipy>=1.5.0
pandas~=1.1.1
xarray~=0.12.3
pandas~=1.2.2
xarray~=0.12.1
tables
uncertainties~=3.1.1
numba~=0.50.1
astropy~=4.0.1.post1
numba~=0.51.2
astropy~=4.2
matplotlib~=3.3.1
tqdm~=4.48.2
tqdm~=4.31.1
semantic_version>=2.8
setuptools~=49.6.0
setuptools~=52.0.0
deprecated~=1.2.10
seaborn
emcee
seaborn~=0.11.1
emcee~=0.0.0
pytransit~=2.5.2
ldtk~=1.0
pyopencl~=2019.1.2
corner~=2.0.1
celerite~=0.4.0
pyrr~=0.10.3
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -31,7 +31,7 @@
'pytransit.utils', 'pytransit.param', 'pytransit.contamination','pytransit.lpf', 'pytransit.lpf.tess',
'pytransit.lpf.baselines','pytransit.lpf.loglikelihood'],
package_data={'':['*.cl'], 'pytransit.contamination':['data/*']},
install_requires=["numpy", "numba", "scipy", "pandas", "xarray", "tables", "semantic_version","deprecated", "uncertainties"],
install_requires=["numpy", "numba", "scipy", "pandas", "xarray", "tables", "semantic_version", "deprecated", "uncertainties"],
extras_require={'celerite': ["celerite","pybind11"]},
include_package_data=True,
license='GPLv2',

0 comments on commit 8c08d91

Please sign in to comment.