@@ -58,50 +58,50 @@ void rope_
58
58
);
59
59
}
60
60
61
- long gen_mrope_pos_ids
61
+ int64_t gen_mrope_pos_ids
62
62
(
63
63
torch::Tensor mrope_pos_ids,
64
64
torch::Tensor ids,
65
65
int merge_size,
66
- const std::vector<std::tuple<long , long >> &spans,
67
- const std::vector<std::tuple<long , long , long >> &grids
66
+ const std::vector<std::tuple<int64_t , int64_t >> &spans,
67
+ const std::vector<std::tuple<int64_t , int64_t , int64_t >> &grids
68
68
)
69
69
{
70
70
int max_length = mrope_pos_ids.size (1 );
71
71
int in_length = ids.size (0 );
72
72
73
- long * in_ids = (long *) ids.data_ptr ();
74
- long * pos_ids = (long *) mrope_pos_ids.data_ptr ();
73
+ int64_t * in_ids = (int64_t *) ids.data_ptr ();
74
+ int64_t * pos_ids = (int64_t *) mrope_pos_ids.data_ptr ();
75
75
76
- long * out_t = pos_ids;
77
- long * out_h = pos_ids + max_length;
78
- long * out_w = pos_ids + 2 * max_length;
76
+ int64_t * out_t = pos_ids;
77
+ int64_t * out_h = pos_ids + max_length;
78
+ int64_t * out_w = pos_ids + 2 * max_length;
79
79
80
- long base_t = 0 ;
81
- long next_base_t = 0 ;
80
+ int64_t base_t = 0 ;
81
+ int64_t next_base_t = 0 ;
82
82
83
83
for (int i = 0 ; i < max_length; ++i)
84
84
{
85
85
bool is_emb = false ;
86
86
if (i < in_length)
87
87
{
88
- long id = in_ids[i];
88
+ int64_t id = in_ids[i];
89
89
90
90
for (int j = 0 ; j < spans.size (); ++j)
91
91
{
92
- long span_start = std::get<0 >(spans[j]);
93
- long span_end = std::get<1 >(spans[j]);
94
- long span = span_end - span_start;
92
+ int64_t span_start = std::get<0 >(spans[j]);
93
+ int64_t span_end = std::get<1 >(spans[j]);
94
+ int64_t span = span_end - span_start;
95
95
if (id >= span_start && id < span_end)
96
96
{
97
97
is_emb = true ;
98
- long k = id - span_start;
99
- long grid_t = std::get<0 >(grids[j]);
100
- long grid_h = std::get<1 >(grids[j]) / (long )merge_size;
101
- long grid_w = std::get<2 >(grids[j]) / (long )merge_size;
102
- long k_t = base_t + (k / grid_w / grid_h) % grid_t ;
103
- long k_h = base_t + (k / grid_w) % grid_h;
104
- long k_w = base_t + k % grid_w;
98
+ int64_t k = id - span_start;
99
+ int64_t grid_t = std::get<0 >(grids[j]);
100
+ int64_t grid_h = std::get<1 >(grids[j]) / (int64_t )merge_size;
101
+ int64_t grid_w = std::get<2 >(grids[j]) / (int64_t )merge_size;
102
+ int64_t k_t = base_t + (k / grid_w / grid_h) % grid_t ;
103
+ int64_t k_h = base_t + (k / grid_w) % grid_h;
104
+ int64_t k_w = base_t + k % grid_w;
105
105
*out_t ++ = k_t ;
106
106
*out_h++ = k_h;
107
107
*out_w++ = k_w;
0 commit comments