diff --git a/tests/test_core.py b/tests/test_core.py index 688eca6..5261e9b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -222,29 +222,42 @@ def test_dataset_empty_constructor(): def test_dataset_example(ds): ds_schema = DatasetSchema( - { + data_vars={ 'foo': DataArraySchema(name='foo', dtype=np.int32, dims=['x']), 'bar': DataArraySchema(name='bar', dtype=np.floating, dims=['x', 'y']), - } + }, + coords={'x': DataArraySchema(name='x', dtype=np.int64, dims=['x'])}, + attrs={}, ) jsonschema.validate(ds_schema.json, ds_schema._json_schema) assert list(ds_schema.json['data_vars'].keys()) == ['foo', 'bar'] + assert list(ds_schema.json['coords']['coords'].keys()) == ['x'] ds_schema.validate(ds) - ds['foo'] = ds.foo.astype('float32') + ds2 = ds.copy() + ds2['foo'] = ds2.foo.astype('float32') with pytest.raises(SchemaError, match='dtype'): - ds_schema.validate(ds) + ds_schema.validate(ds2) - ds = ds.drop_vars('foo') + ds2 = ds2.drop_vars('foo') with pytest.raises(SchemaError, match='variable foo'): - ds_schema.validate(ds) + ds_schema.validate(ds2) + + ds3 = ds.copy() + ds3['x'] = ds3.x.astype('float32') + with pytest.raises(SchemaError, match='dtype'): + ds_schema.validate(ds3) + + ds3 = ds3.drop_vars('x') + with pytest.raises(SchemaError, match='coords has missing keys'): + ds_schema.validate(ds3) # json roundtrip rt_schema = DatasetSchema.from_json(ds_schema.json) assert isinstance(rt_schema, DatasetSchema) - rt_schema.json == ds_schema.json + assert rt_schema.json == ds_schema.json def test_checks_ds(ds): diff --git a/xarray_schema/dataset.py b/xarray_schema/dataset.py index 9be35ee..a7d2289 100644 --- a/xarray_schema/dataset.py +++ b/xarray_schema/dataset.py @@ -47,9 +47,9 @@ def from_json(cls, obj: dict): k: DataArraySchema.from_json(v) for k, v in obj['data_vars'].items() } if 'coords' in obj: - kwargs['coords'] = {k: CoordsSchema.from_json(v) for k, v in obj['coords'].items()} + kwargs['coords'] = CoordsSchema.from_json(obj['coords']) if 'attrs' in obj: - kwargs['attrs'] = {k: AttrsSchema.from_json(v) for k, v in obj['attrs'].items()} + kwargs['attrs'] = AttrsSchema.from_json(obj['attrs']) return cls(**kwargs) @@ -79,8 +79,8 @@ def validate(self, ds: xr.Dataset) -> None: else: da_schema.validate(ds.data_vars[key]) - if self.coords is not None: # pragma: no cover - raise NotImplementedError('coords schema not implemented yet') + if self.coords is not None: + self.coords.validate(ds.coords) if self.attrs: self.attrs.validate(ds.attrs)