Skip to content

Commit b616c6b

Browse files
committed
impl(wkt): custom encoding for i64 and u64
1 parent a0d101d commit b616c6b

File tree

4 files changed

+297
-0
lines changed

4 files changed

+297
-0
lines changed

src/wkt/src/internal.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@
2020
mod int32;
2121
pub use int32::I32;
2222

23+
#[macro_use]
24+
mod visitor_64;
25+
mod int64;
26+
pub use int64::I64;
27+
mod uint64;
28+
pub use uint64::U64;
29+
2330
pub struct F32;
2431
pub struct F64;
2532

src/wkt/src/internal/int64.rs

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
//! Implement custom serializers for `i64`.
16+
//!
17+
//! In ProtoJSON 64-bit integers can be serialized as strings, but can be
18+
//! deserialized from strings or numbers.
19+
20+
use serde::de::Unexpected::Other;
21+
22+
pub struct I64;
23+
24+
impl<'de> serde_with::DeserializeAs<'de, i64> for I64 {
25+
fn deserialize_as<D>(deserializer: D) -> Result<i64, D::Error>
26+
where
27+
D: serde::de::Deserializer<'de>,
28+
{
29+
deserializer.deserialize_any(I64Visitor)
30+
}
31+
}
32+
33+
visitor_64!(I64Visitor, i64, "a 64-bit signed integer");
34+
35+
impl serde_with::SerializeAs<i64> for I64 {
36+
fn serialize_as<S>(source: &i64, serializer: S) -> Result<S::Ok, S::Error>
37+
where
38+
S: serde::Serializer,
39+
{
40+
serializer.collect_str(source)
41+
}
42+
}
43+
44+
#[cfg(test)]
45+
mod test {
46+
use super::*;
47+
use anyhow::Result;
48+
use serde_json::{Value, json};
49+
use serde_with::{DeserializeAs, SerializeAs};
50+
use test_case::test_case;
51+
52+
#[test_case(0, 0)]
53+
#[test_case("0", 0; "zero string")]
54+
#[test_case("2.0", 2)]
55+
#[test_case(3e5, 300_000)]
56+
#[test_case(-4e4, -40_000)]
57+
#[test_case("5e4", 50_000)]
58+
#[test_case("-6e5", -600_000)]
59+
#[test_case(-42, -42)]
60+
#[test_case("-7", -7)]
61+
#[test_case(84, 84)]
62+
#[test_case(168.0, 168)]
63+
#[test_case("21", 21)]
64+
#[test_case(i64::MAX, i64::MAX; "max")]
65+
#[test_case(i64::MAX as f64, i64::MAX; "max as f64")]
66+
#[test_case(format!("{}", i64::MAX), i64::MAX; "max as string")]
67+
#[test_case(format!("{}.0", i64::MAX), i64::MAX; "max as f64 string")]
68+
#[test_case(i64::MIN, i64::MIN; "min")]
69+
#[test_case(i64::MIN as f64, i64::MIN; "min as f64")]
70+
#[test_case(format!("{}", i64::MIN), i64::MIN; "min as string")]
71+
#[test_case(format!("{}.0", i64::MIN), i64::MIN; "min as f64 string")]
72+
// Not quite a roundtrip test because we always serialize as numbers.
73+
fn deser_and_ser<T: serde::Serialize>(input: T, want: i64) -> Result<()> {
74+
let got = I64::deserialize_as(json!(input))?;
75+
assert_eq!(got, want);
76+
77+
let serialized = I64::serialize_as(&got, serde_json::value::Serializer)?;
78+
assert_eq!(serialized, json!(got.to_string()));
79+
Ok(())
80+
}
81+
82+
#[test_case(json!(i64::MAX as f64 * 2.0))]
83+
#[test_case(json!(i64::MIN as f64 * 2.0))]
84+
#[test_case(json!(format!("{}", i64::MAX as f64 * 2.0)))]
85+
#[test_case(json!(format!("{}", i64::MIN as f64 * 2.0)))]
86+
#[test_case(json!(format!("{}", i64::MAX as i128 + 1)); "MAX+1 as string")]
87+
#[test_case(json!(format!("{}", i64::MIN as i128 - 1)); "MIN-1 as string")]
88+
#[test_case(json!("abc"))]
89+
#[test_case(json!(123.4))]
90+
#[test_case(json!("234.5"))]
91+
#[test_case(json!("-345.6"))]
92+
#[test_case(json!({}))]
93+
fn deser_error(input: Value) {
94+
let got = I64::deserialize_as(input).unwrap_err();
95+
assert!(got.is_data(), "{got:?}");
96+
}
97+
}

src/wkt/src/internal/uint64.rs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
//! Implement custom serializers for `u64`.
16+
//!
17+
//! In ProtoJSON 64-bit integers can be serialized as either strings or numbers.
18+
19+
use serde::de::Unexpected::Other;
20+
21+
pub struct U64;
22+
23+
impl<'de> serde_with::DeserializeAs<'de, u64> for U64 {
24+
fn deserialize_as<D>(deserializer: D) -> Result<u64, D::Error>
25+
where
26+
D: serde::de::Deserializer<'de>,
27+
{
28+
deserializer.deserialize_any(U64Visitor)
29+
}
30+
}
31+
32+
visitor_64!(U64Visitor, u64, "a 64-bit unsigned integer");
33+
34+
impl serde_with::SerializeAs<u64> for U64 {
35+
fn serialize_as<S>(source: &u64, serializer: S) -> Result<S::Ok, S::Error>
36+
where
37+
S: serde::Serializer,
38+
{
39+
serializer.collect_str(source)
40+
}
41+
}
42+
43+
#[cfg(test)]
44+
mod test {
45+
use super::*;
46+
use anyhow::Result;
47+
use serde_json::{Value, json};
48+
use serde_with::{DeserializeAs, SerializeAs};
49+
use test_case::test_case;
50+
51+
#[test_case(0, 0)]
52+
#[test_case("0", 0; "zero string")]
53+
#[test_case("2.0", 2)]
54+
#[test_case(3e5, 300_000)]
55+
#[test_case("5e4", 50_000)]
56+
#[test_case(84, 84)]
57+
#[test_case(168.0, 168)]
58+
#[test_case("21", 21)]
59+
#[test_case(u64::MAX, u64::MAX; "max")]
60+
#[test_case(u64::MAX as f64, u64::MAX; "max as f64")]
61+
#[test_case(format!("{}", u64::MAX), u64::MAX; "max as string")]
62+
#[test_case(format!("{}.0", u64::MAX), u64::MAX; "max as f64 string")]
63+
#[test_case(u64::MIN, u64::MIN; "min")]
64+
#[test_case(u64::MIN as f64, u64::MIN; "min as f64")]
65+
#[test_case(format!("{}", u64::MIN), u64::MIN; "min as string")]
66+
#[test_case(format!("{}.0", u64::MIN), u64::MIN; "min as f64 string")]
67+
// Not quite a roundtrip test because we always serialize as numbers.
68+
fn deser_and_ser<T: serde::Serialize>(input: T, want: u64) -> Result<()> {
69+
let got = U64::deserialize_as(json!(input))?;
70+
assert_eq!(got, want);
71+
72+
let serialized = U64::serialize_as(&got, serde_json::value::Serializer)?;
73+
assert_eq!(serialized, json!(got.to_string()));
74+
Ok(())
75+
}
76+
77+
#[test_case(json!(u64::MAX as f64 * 2.0))]
78+
#[test_case(json!(format!("{}", u64::MAX as f64 * 2.0)))]
79+
#[test_case(json!(format!("{}", u64::MAX as i128 + 1)); "MAX+1 as string")]
80+
#[test_case(json!(format!("{}", u64::MIN as i128 - 1)); "MIN-1 as string")]
81+
#[test_case(json!("abc"))]
82+
#[test_case(json!(123.4))]
83+
#[test_case(json!("234.5"))]
84+
#[test_case(json!({}))]
85+
fn deser_error(input: Value) {
86+
let got = U64::deserialize_as(input).unwrap_err();
87+
assert!(got.is_data(), "{got:?}");
88+
}
89+
}

src/wkt/src/internal/visitor_64.rs

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
// Copyright 2025 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// https://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
macro_rules! visitor_64 {
16+
($name: ident, $t: ty, $msg: literal) => {
17+
struct $name;
18+
19+
impl serde::de::Visitor<'_> for $name {
20+
type Value = $t;
21+
22+
fn visit_str<E>(self, value: &str) -> std::result::Result<Self::Value, E>
23+
where
24+
E: serde::de::Error,
25+
{
26+
// ProtoJSON says that both strings and numbers are accepted.
27+
// The difficulty is dealing with integer strings that are just
28+
// outside the range for i64 or u64. Parsing as `f64` rounds
29+
// those numbers and may result in incorrectly accepting the
30+
// value.
31+
//
32+
// First try to parse the string as a i128, if it works and it
33+
// is in range, return that value.
34+
if let Ok(v) = value.parse::<i128>() {
35+
return self.visit_i128(v);
36+
}
37+
// Next, try to parse it as a `f64` number (all JSON numbers are
38+
// `f64`) and then try to parse the result as the target type.
39+
let number = value.parse::<f64>().map_err(E::custom)?;
40+
self.visit_f64(number)
41+
}
42+
43+
fn visit_i64<E>(self, value: i64) -> std::result::Result<Self::Value, E>
44+
where
45+
E: serde::de::Error,
46+
{
47+
match value {
48+
_ if value < <$t>::MIN as i64 => Err(self::value_error(value)),
49+
_ => Ok(value as Self::Value),
50+
}
51+
}
52+
53+
fn visit_u64<E>(self, value: u64) -> std::result::Result<Self::Value, E>
54+
where
55+
E: serde::de::Error,
56+
{
57+
match value {
58+
_ if value > <$t>::MAX as u64 => Err(self::value_error(value)),
59+
_ => Ok(value as Self::Value),
60+
}
61+
}
62+
63+
fn visit_f64<E>(self, value: f64) -> std::result::Result<Self::Value, E>
64+
where
65+
E: serde::de::Error,
66+
{
67+
match value {
68+
_ if value < <$t>::MIN as f64 => Err(self::value_error(value)),
69+
_ if value > <$t>::MAX as f64 => Err(self::value_error(value)),
70+
_ if value.fract() != 0.0 => Err(self::value_error(value)),
71+
// In Rust floating point to integer conversions are
72+
// "rounded towards zero":
73+
// https://doc.rust-lang.org/reference/expressions/operator-expr.html#r-expr.as.numeric.float-as-int
74+
// Because we are in range, and the fractional part is zero,
75+
// this conversion is safe.
76+
_ => Ok(value as Self::Value),
77+
}
78+
}
79+
80+
fn visit_i128<E>(self, value: i128) -> std::result::Result<Self::Value, E>
81+
where
82+
E: serde::de::Error,
83+
{
84+
match value {
85+
_ if value < <$t>::MIN as i128 => Err(self::value_error(value)),
86+
_ if value > <$t>::MAX as i128 => Err(self::value_error(value)),
87+
_ => Ok(value as Self::Value),
88+
}
89+
}
90+
91+
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
92+
formatter.write_str($msg)
93+
}
94+
}
95+
96+
fn value_error<T, E>(value: T) -> E
97+
where
98+
T: std::fmt::Display,
99+
E: serde::de::Error,
100+
{
101+
E::invalid_value(Other(&format!("{value}")), &$msg)
102+
}
103+
};
104+
}

0 commit comments

Comments
 (0)