Code using OmegaConf to handle I/O.
IO
source
load_config
def load_config(
file_path
):
source
config_to_dict
def config_to_dict(
config
):
source
save_dataclass_yaml
def save_dataclass_yaml(
data_obj, file_path
):
source
save_dict_yaml
def save_dict_yaml(
dict_obj, file_path
):
Test
@dataclass
class MyConfig:
    """Example dataclass that is turned into an OmegaConf structured config below."""
    # Filled by `class_to_str(OmegaConf)`; the output shown after this example
    # gives its value as 'omegaconf.omegaconf.OmegaConf'.
    target:str = class_to_str(OmegaConf)
    clr_dim: int = 80
    # NOTE(review): annotated `list[int]` but defaulted to None (presumably to
    # avoid a mutable list default on a dataclass field). `Optional[list[int]]`
    # would state that None is allowed; confirm OmegaConf accepts this as written.
    features: list[int]=None
# Create an instance, fill in the list field, then wrap it with OmegaConf.structured.
c = MyConfig()
c.features = [1,2,3]
OmegaConf.structured(c)
{'target': 'omegaconf.omegaconf.OmegaConf', 'clr_dim': 80, 'features': [1, 2, 3]}
Loading objects from config
Adapted from: https://github.com/Stability-AI/generative-models
source
get_obj_from_str
def get_obj_from_str(
string, reload:bool=False, invalidate_cache:bool=True
):
source
instantiate_from_config
def instantiate_from_config(
config
):
store_model_state_dict
def store_model_state_dict(
state_dict, save_path
):
source
load_model_state_dict
def load_model_state_dict(
save_path, device
):
Tensors and NumPy
torch.serialization.DEFAULT_PROTOCOL
source
store_tensor
def store_tensor(
tensor, save_path, type:str='tensor'
):
source
load_tensor
def load_tensor(
save_path, device, type:str='tensor'
):
Back to top