Skip to content

Commit

Permalink
ruff
Browse files Browse the repository at this point in the history
  • Loading branch information
dkweiss31 committed Sep 28, 2024
1 parent e9c2cab commit 9f5593d
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 43 deletions.
5 changes: 3 additions & 2 deletions floquet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

from .amplitude_converters import ChiacToAmp as ChiacToAmp
from .amplitude_converters import XiSqToAmp as XiSqToAmp
from .floquet import DisplacedState as DisplacedState
from .floquet import DisplacedStateFit as DisplacedStateFit
from .floquet import DriveParameters as DriveParameters
from .floquet import floquet_analysis as floquet_analysis
from .floquet import floquet_analysis_from_file as floquet_analysis_from_file
from .floquet import DriveParameters as DriveParameters
from .floquet import (DisplacedState as DisplacedState, DisplacedStateFit as DisplacedStateFit)
from .options import Options
from .utils.file_io import extract_info_from_h5 as extract_info_from_h5
from .utils.file_io import generate_file_path as generate_file_path
Expand Down
113 changes: 74 additions & 39 deletions floquet/floquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,14 @@ def overlap_with_bare_states(
):
"""Similar calculation to _overlap_with_displaced state but here we take
advantage of the fact that the displaced state should only be computed on the
boundary."""
boundary.
"""
overlaps = np.zeros(
(len(self.drive_parameters.omega_d_values),
int(amp_idxs[1] - amp_idxs[0]),
len(self.state_indices),
)
(
len(self.drive_parameters.omega_d_values),
int(amp_idxs[1] - amp_idxs[0]),
len(self.state_indices),
)
)
for array_idx, state_idx in enumerate(self.state_indices):

Expand All @@ -158,7 +160,7 @@ def _compute_bare_state(omega_d: float) -> np.ndarray:
omega_d,
self.drive_parameters.drive_amplitudes[amp_idxs[0], omega_d_idx],
state_idx,
coefficients=coefficients[array_idx]
coefficients=coefficients[array_idx],
).full()[:, 0]

bare_states = np.array(
Expand All @@ -173,7 +175,11 @@ def _compute_bare_state(omega_d: float) -> np.ndarray:
# indices are i: omega_d, j: amp, k: components of state
# this serves as the mask that is passed to _disp_coeffs_fit
overlaps[:, :, array_idx] = np.abs(
np.einsum('ijk,ik->ij', floquet_data[:, amp_idxs[0]: amp_idxs[1], array_idx], np.conj(bare_states))
np.einsum(
'ijk,ik->ij',
floquet_data[:, amp_idxs[0]: amp_idxs[1], array_idx],
np.conj(bare_states),
)
)
return overlaps

Expand All @@ -195,15 +201,20 @@ def _run_overlap_displaced(omega_d_amp: tuple[float, float]) -> np.ndarray:
array_idx,
]
disp_state = self.displaced_state(
omega_d, amp, state_idx=state_idx, coefficients=coefficients[array_idx]
omega_d,
amp,
state_idx=state_idx,
coefficients=coefficients[array_idx],
).dag()
overlap[array_idx] = np.abs(disp_state.data.toarray()[0] @ floquet_data_for_idx)
overlap[array_idx] = np.abs(
disp_state.data.toarray()[0] @ floquet_data_for_idx
)
return overlap

omega_d_amp_params = self.drive_parameters.omega_d_amp_params(amp_idxs)
amp_range_vals = self.drive_parameters.drive_amplitudes[
amp_idxs[0]: amp_idxs[1]
]
amp_idxs[0] : amp_idxs[1]
]
result = list(
parallel_map(
self.options.num_cpus, _run_overlap_displaced, omega_d_amp_params
Expand Down Expand Up @@ -251,7 +262,11 @@ def _coefficient_for_state(
result = 1.0 if bare_same else 0.0
for idx in exp_pair_map:
exp_pair = exp_pair_map[idx]
result += state_idx_coefficients[idx] * omega_d ** exp_pair[0] * amp ** exp_pair[1]
result += (
state_idx_coefficients[idx]
* omega_d ** exp_pair[0]
* amp ** exp_pair[1]
)
return result

def _create_exponent_pair_idx_map(self) -> dict:
Expand Down Expand Up @@ -289,7 +304,10 @@ def _create_exponent_pair_idx_map(self) -> dict:

class DisplacedStateFit(DisplacedState):
def displaced_states_fit(
self, amp_idxs: list, ovlp_with_bare_states: np.ndarray, floquet_data: np.ndarray
self,
amp_idxs: list,
ovlp_with_bare_states: np.ndarray,
floquet_data: np.ndarray,
) -> np.ndarray:
"""Loop over all states and perform the fit for a given amplitude range.
Expand All @@ -311,9 +329,7 @@ def _fit_for_state_idx(array_state_idx: tuple[int, int]) -> np.ndarray:
# currently investigating. Asking for the state for amplitude values outside of
# the fit window should be done at your own peril.
array_idx, state_idx = array_state_idx
floquet_idx_data = floquet_data[
:, amp_idxs[0]: amp_idxs[1], array_idx, :
]
floquet_idx_data = floquet_data[:, amp_idxs[0]: amp_idxs[1], array_idx, :]
mask = ovlp_with_bare_states[:, :, array_idx].ravel()
omega_d_amp_data_slice = list(
self.drive_parameters.omega_d_amp_params(amp_idxs)
Expand All @@ -330,13 +346,18 @@ def _fit_for_state_idx(array_state_idx: tuple[int, int]) -> np.ndarray:
(self.hilbert_dim, num_coeffs), dtype=complex
)
if len(omega_d_amp_filtered) < len(self.exponent_pair_idx_map):
warnings.warn('Not enough data points to fit. Returning zeros for the fit', stacklevel=3)
warnings.warn(
'Not enough data points to fit. Returning zeros for the fit',
stacklevel=3,
)
return coefficient_matrix_for_amp_and_state
for state_idx_component in range(self.hilbert_dim):
floquet_idx_data_bare_component = floquet_idx_data[:, :, state_idx_component].ravel()
floquet_idx_data_bare_component = floquet_idx_data[
:, :, state_idx_component
].ravel()
floquet_component_filtered = floquet_idx_data_bare_component[
np.abs(mask) > self.options.overlap_cutoff
]
]
bare_same = state_idx_component == state_idx
bare_component_fit = self._fit_coefficients_for_component(
omega_d_amp_filtered, floquet_component_filtered, bare_same
Expand All @@ -352,7 +373,11 @@ def _fit_for_state_idx(array_state_idx: tuple[int, int]) -> np.ndarray:
parallel_map(self.options.num_cpus, _fit_for_state_idx, array_state_idxs)
)
return np.array(fit_data, dtype=complex).reshape(
(len(self.state_indices), self.hilbert_dim, len(self._create_exponent_pair_idx_map()))
(
len(self.state_indices),
self.hilbert_dim,
len(self._create_exponent_pair_idx_map()),
)
)

def _fit_coefficients_for_component(
Expand Down Expand Up @@ -440,10 +465,15 @@ def param_dict(self) -> dict:
"""Collect all attributes for writing to file, including derived ones."""
if self.init_data_to_save is None:
self.init_data_to_save = {}
return vars(self.options) | vars(self.drive_parameters) | self.init_data_to_save | {
'hilbert_dim': self.hilbert_dim,
'floquet_analysis_init': self.get_initdata(),
}
return (
vars(self.options)
| vars(self.drive_parameters)
| self.init_data_to_save
| {
'hilbert_dim': self.hilbert_dim,
'floquet_analysis_init': self.get_initdata(),
}
)

def __str__(self) -> str:
params = self.param_dict()
Expand Down Expand Up @@ -544,7 +574,9 @@ def bare_state_array(self) -> np.ndarray:
Used to specify initial bare states for the Blais branch analysis.
"""
return np.squeeze(
np.array([qt.basis(self.hilbert_dim, idx) for idx in range(self.hilbert_dim)])
np.array(
[qt.basis(self.hilbert_dim, idx) for idx in range(self.hilbert_dim)]
)
)

def _step_in_amp(
Expand Down Expand Up @@ -620,9 +652,7 @@ def run(self, filepath: str = 'tmp.h5py') -> dict:
# fit over the full range, using states identified by the fit over
# intermediate ranges
_displaced_state_overlaps = np.zeros(array_shape)
floquet_mode_data = np.zeros(
(*array_shape, self.hilbert_dim), dtype=complex
)
floquet_mode_data = np.zeros((*array_shape, self.hilbert_dim), dtype=complex)
avg_excitation = np.zeros(
(
len(self.drive_parameters.omega_d_values),
Expand All @@ -645,7 +675,7 @@ def run(self, filepath: str = 'tmp.h5py') -> dict:
hilbert_dim=self.hilbert_dim,
drive_parameters=self.drive_parameters,
state_indices=self.state_indices,
options=self.options
options=self.options,
)
previous_coefficients = np.array(
[
Expand All @@ -664,22 +694,21 @@ def run(self, filepath: str = 'tmp.h5py') -> dict:
amp_range_idx_final = len(self.drive_parameters.drive_amplitudes)
else:
amp_range_idx_final = (amp_range_idx + 1) * num_amp_pts_per_range
amp_idxs = [
amp_range_idx * num_amp_pts_per_range,
amp_range_idx_final,
]
amp_idxs = [amp_range_idx * num_amp_pts_per_range, amp_range_idx_final]
# now perform floquet mode calculation for amp_range_idx
# need to pass forward the floquet modes from the previous amp range
# which allow us to identify floquet modes that may have been displaced
# far from the origin
output = self._floquet_main_for_amp_range(
amp_idxs, displaced_state, previous_coefficients, prev_f_modes_arr
)
(max_overlap_data_for_range,
floquet_mode_data_for_range,
avg_excitation_for_range,
quasienergies_for_range,
prev_f_modes_arr) = output
(
max_overlap_data_for_range,
floquet_mode_data_for_range,
avg_excitation_for_range,
quasienergies_for_range,
prev_f_modes_arr,
) = output
max_overlap_data = self._place_into(
amp_idxs, max_overlap_data_for_range, max_overlap_data
)
Expand Down Expand Up @@ -836,4 +865,10 @@ def _run_floquet_and_calculate(
)
max_overlap_data = np.abs(floquet_mode_array[..., 0])
floquet_mode_data = floquet_mode_array[..., 1:]
return max_overlap_data, floquet_mode_data, all_avg_excitation, all_quasienergies, f_modes_last_amp
return (
max_overlap_data,
floquet_mode_data,
all_avg_excitation,
all_quasienergies,
f_modes_last_amp,
)
4 changes: 2 additions & 2 deletions tests/test_floquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@

from floquet import (
ChiacToAmp,
DisplacedState,
DriveParameters,
Options,
XiSqToAmp,
floquet_analysis,
floquet_analysis_from_file,
DriveParameters,
DisplacedState
)


Expand Down

0 comments on commit 9f5593d

Please sign in to comment.