Fixes TARDIS data export and hashing using pickle (sadly)
andrewfullard committed Sep 10, 2024
1 parent ea11a98 commit e5e3bac
Showing 8 changed files with 34 additions and 22 deletions.
carsus/io/output/base.py (2 changes: 1 addition & 1 deletion)
@@ -220,7 +220,7 @@ def to_hdf(self, fname):
        total_checksum = hashlib.md5()
        for key in f.keys():
            # update the total checksum to sign the file
-           total_checksum.update(serialize_pandas_object(f[key]).to_buffer())
+           total_checksum.update(serialize_pandas_object(f[key]))

            # save individual DataFrame/Series checksum
            checksum = hash_pandas_object(f[key])
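The old call chained `.to_buffer()` because the earlier serializer returned an object exposing that method; `pickle.dumps` already returns `bytes`, which `hashlib` consumes directly. A minimal sketch of the signing pattern, where the `frames` dict stands in for the open HDF store `f`:

import hashlib
import pickle

import pandas as pd

# stand-ins for the DataFrames stored under each HDF key
frames = {
    "/levels": pd.DataFrame({"energy": [0.0, 1.2]}),
    "/lines": pd.DataFrame({"gf": [0.5, 0.1]}),
}

total_checksum = hashlib.md5()
for key in frames:
    # pickle.dumps returns bytes, so no .to_buffer() is needed
    total_checksum.update(pickle.dumps(frames[key]))

print(total_checksum.hexdigest())  # signs the file as a whole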

carsus/io/output/levels_lines.py (13 changes: 6 additions & 7 deletions)
@@ -110,20 +110,19 @@ def ingest_multiple_sources(self, attribute):
        """
        gfall = getattr(self.gfall_reader, attribute)
        gfall["ds_id"] = 2
+       sources = [gfall]

        if self.chianti_reader is not None:
            chianti = getattr(self.chianti_reader, attribute)
            chianti["ds_id"] = 4
-       else:
-           chianti = pd.DataFrame(columns=gfall.columns)
+           sources.append(chianti)

        if self.cmfgen_reader is not None:
            cmfgen = getattr(self.cmfgen_reader, attribute)
            cmfgen["ds_id"] = 5
-       else:
-           cmfgen = pd.DataFrame(columns=gfall.columns)
+           sources.append(cmfgen)

-       return pd.concat([gfall, chianti, cmfgen], sort=True)
+       return pd.concat(sources, sort=True)

    # replace with functools.cached_property with Python > 3.8
    @property
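The rewrite above collects only the readers that are actually present instead of padding missing ones with empty placeholder frames. A standalone sketch of the pattern, with made-up stand-ins for the reader attributes:

import pandas as pd

# stand-ins for reader attributes; chianti_reader is absent in this run
gfall = pd.DataFrame({"energy": [0.0, 1.0]})
gfall["ds_id"] = 2
sources = [gfall]

chianti_reader = None
if chianti_reader is not None:
    chianti = chianti_reader.copy()
    chianti["ds_id"] = 4
    sources.append(chianti)

# only sources that actually exist reach the concat
combined = pd.concat(sources, sort=True)
print(combined)

Skipping absent sources avoids concatenating empty all-NaN frames, which newer pandas versions warn about and which can silently upcast column dtypes.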
@@ -478,7 +477,7 @@ def create_levels_lines(
        lines["f_ul"] = lines["gf"] / lines["g_u"]

        # Calculate frequency
-       lines["nu"] = u.Quantity(lines["wavelength"], "AA").to("Hz", u.spectral())
+       lines["nu"] = u.Quantity(lines["wavelength"], "AA").to("Hz", u.spectral()).value

        # Create Einstein coefficients
        create_einstein_coeff(lines)
@@ -490,7 +489,7 @@

        # Create and append artificial levels for fully ionized ions
        artificial_fully_ionized_levels = create_artificial_fully_ionized(levels)
-       levels = levels.append(artificial_fully_ionized_levels, ignore_index=True)
+       levels = pd.concat([levels, artificial_fully_ionized_levels], ignore_index=True)
        levels = levels.sort_values(["atomic_number", "ion_number", "level_number"])

        self.levels = levels
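`DataFrame.append` was deprecated in pandas 1.4 and removed in 2.0; `pd.concat` is the drop-in replacement. A small sketch with made-up levels data:

import pandas as pd

levels = pd.DataFrame({"atomic_number": [1, 1], "level_number": [0, 1]})
artificial = pd.DataFrame({"atomic_number": [1], "level_number": [-1]})

# pd.concat replaces the removed DataFrame.append; ignore_index renumbers rows
levels = pd.concat([levels, artificial], ignore_index=True)
print(levels)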

carsus/io/output/macro_atom.py (2 changes: 1 addition & 1 deletion)
@@ -48,7 +48,7 @@ def create_macro_atom(self):
            ("target_level_number", np.int64),
            ("transition_line_id", np.int64),
            ("transition_type", np.int64),
-           ("transition_probability", np.float),
+           ("transition_probability", np.float64),
        ]

        for line_id, row in lines.iterrows():
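`np.float` was only ever an alias for the builtin `float`; it was deprecated in NumPy 1.20 and removed in 1.24, so dtype specs must spell out `np.float64` (or use `float`). A sketch of how such a (name, dtype) spec behaves as a structured dtype; the array construction here is illustrative, not carsus's usage:

import numpy as np

# np.float was removed in NumPy 1.24; name the width explicitly instead
macro_atom_dtype = [
    ("transition_line_id", np.int64),
    ("transition_probability", np.float64),
]
rows = np.zeros(2, dtype=macro_atom_dtype)
print(rows.dtype)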

carsus/io/output/photo_ionization.py (4 changes: 2 additions & 2 deletions)
@@ -63,8 +63,8 @@ def cross_sections(self):

        cross_sections["energy"] = u.Quantity(cross_sections["energy"], "Ry").to(
            "Hz", equivalencies=u.spectral()
-       )
-       cross_sections["sigma"] = u.Quantity(cross_sections["sigma"], "Mbarn").to("cm2")
+       ).value
+       cross_sections["sigma"] = u.Quantity(cross_sections["sigma"], "Mbarn").to("cm2").value
        cross_sections["level_number"] = cross_sections["level_number"].astype("int")
        cross_sections = cross_sections.rename(
            columns={"energy": "nu", "sigma": "x_sect"}
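Like the `nu` fix in levels_lines.py, both conversions now end with `.value`: storing a `Quantity` in a DataFrame column leaves awkward object-dtype cells, while `.value` yields a plain float64 array. A sketch of the conversion with made-up sample data:

import astropy.units as u
import pandas as pd

cross_sections = pd.DataFrame({"energy": [1.0, 2.0]})  # Rydberg

# .to() returns a Quantity; .value strips the unit to a float64 ndarray.
# u.spectral() supplies the energy <-> frequency equivalence (E = h * nu).
cross_sections["energy"] = (
    u.Quantity(cross_sections["energy"], "Ry")
    .to("Hz", equivalencies=u.spectral())
    .value
)
print(cross_sections["energy"])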

carsus/io/tests/test_cmfgen.py (10 changes: 5 additions & 5 deletions)
@@ -21,7 +21,7 @@
def si1_reader():
    return CMFGENReader.from_config(
        "Si 0-1",
-       atomic_path="/tmp/atomic",
+       atomic_path="/home/afullard/carsus-data-cmfgen/atomic",
        collisions=True,
        cross_sections=True,
        ionization_energies=True,
@@ -126,25 +126,25 @@ def test_CMFGENHydGauntBfParser(cmfgen_refdata_fname):
    return parser.base


-@pytest.mark.with_refdata
+#@pytest.mark.with_refdata
@pytest.mark.array_compare(file_format="pd_hdf")
def test_reader_lines(si1_reader):
    return si1_reader.lines


-@pytest.mark.with_refdata
+#@pytest.mark.with_refdata
@pytest.mark.array_compare(file_format="pd_hdf")
def test_reader_levels(si1_reader):
    return si1_reader.levels


-@pytest.mark.with_refdata
+#@pytest.mark.with_refdata
@pytest.mark.array_compare(file_format="pd_hdf")
def test_reader_collisions(si1_reader):
    return si1_reader.collisions


-@pytest.mark.with_refdata
+#@pytest.mark.with_refdata
@pytest.mark.array_compare(file_format="pd_hdf")
def test_reader_cross_sections_squeeze(si1_reader):
    return si1_reader.cross_sections
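`with_refdata` is a custom pytest marker; commenting it out makes these tests run unconditionally against the hard-coded local path above. For reference, a generic sketch of how such an opt-in marker is typically wired up in a `conftest.py` (an assumption about the common pattern, not carsus's actual conftest):

# conftest.py: generic opt-in marker pattern (not carsus's implementation)
import pytest

def pytest_addoption(parser):
    parser.addoption("--refdata", action="store_true", default=False,
                     help="run tests that need reference data")

def pytest_collection_modifyitems(config, items):
    if config.getoption("--refdata"):
        return  # flag given: run everything
    skip = pytest.mark.skip(reason="needs --refdata")
    for item in items:
        if "with_refdata" in item.keywords:
            item.add_marker(skip)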

carsus/util/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -13,4 +13,4 @@
)
from carsus.util.selected import parse_selected_atoms, parse_selected_species

-from carsus.util.hash import hash_pandas_object
+from carsus.util.hash import serialize_pandas_object, hash_pandas_object

carsus/util/hash.py (18 changes: 16 additions & 2 deletions)
@@ -1,5 +1,19 @@
import pandas as pd
import hashlib
+import pickle
+
+
+def serialize_pandas_object(pd_object):
+    """Serialize pandas objects with pickle.
+
+    Parameters
+    ----------
+    pd_object : pandas.Series or pandas.DataFrame
+        Pandas object to be serialized with pickle.
+
+    Returns
+    -------
+    bytes
+        Pickle-serialized Python object.
+    """
+    return pickle.dumps(pd_object)
+
+
def hash_pandas_object(pd_object, algorithm="md5"):
@@ -30,4 +44,4 @@ def hash_pandas_object(pd_object, algorithm="md5"):
    else:
        raise ValueError('algorithm not supported')

-   return hash_func(pd.util.hash_pandas_object(pd_object).values).hexdigest()
+   return hash_func(serialize_pandas_object(pd_object)).hexdigest()
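With both pieces in place, the digest now covers the pickled object as a whole (index, dtypes, and column order included) rather than only the per-row values from `pd.util.hash_pandas_object`. A condensed sketch of the new path; the `getattr` dispatch is a simplification of the module's explicit branch-by-branch algorithm selection:

import hashlib
import pickle

import pandas as pd

def hash_pandas_object(pd_object, algorithm="md5"):
    # condensed: the real module selects the algorithm branch by branch
    hash_func = getattr(hashlib, algorithm, None)
    if hash_func is None:
        raise ValueError("algorithm not supported")
    # hash the pickled bytes, so structure as well as values is signed
    return hash_func(pickle.dumps(pd_object)).hexdigest()

print(hash_pandas_object(pd.DataFrame([(0, 1), (1, 2)])))

The trade-off, which the "(sadly)" in the commit title acknowledges, is that pickle output is not guaranteed byte-stable across Python or pandas versions, so these checksums are only dependable within a pinned environment.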

carsus/util/tests/test_hash.py (5 changes: 2 additions & 3 deletions)
@@ -1,14 +1,13 @@
import pytest
import hashlib
import pandas as pd
from carsus.util import hash_pandas_object


@pytest.mark.parametrize(
    "values, md5",
    [
-       ([(0, 1), (1, 2), (2, 3), (3, 4)], "9a6b127b25"),
-       (["apple", "banana", "orange"], "924a349b83"),
+       ([(0, 1), (1, 2), (2, 3), (3, 4)], "12b31eadd7"),
+       (["apple", "banana", "orange"], "89b33d7168"),
    ],
)
def test_hash_pd(values, md5):
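The expected digests change because the bytes being hashed changed: the old values signed `pd.util.hash_pandas_object(...).values`, the new ones sign the pickled frame. The test body is collapsed in this view; given the ten-character expectations, a plausible shape would be (an assumption, not the actual test code):

# hypothetical test body matching the ten-character expected prefixes
def test_hash_pd(values, md5):
    df = pd.DataFrame(values)
    assert hash_pandas_object(df)[:10] == md5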
