Skip to content

Commit

Permalink
kafl/tdx: implement atomic cache for virtio
Browse files Browse the repository at this point in the history
When fuzzing TDX, the endianness conversions from the bounce buffer
results in unique values every time, which is impossible since the
buffer is copied and not modified until invalidated later. IE a read
at offset X within the buffer should always yield the same value but
currently results in a new random fuzz value.

To correct this, implement a cache for endianness conversions from the
bounce buffer based on the unique device pointer. Just cache one u64 since
all the values are smaller than u64. Invalidate the cache when the buffer
is invalidated.

Implement this cache within the virtio_device struct so each
virtio_device gets its own cache and use an atomic. This avoids the need
to share a global cache and any associated locking.

Signed-off-by: William Roberts <william.c.roberts@intel.com>
  • Loading branch information
William Roberts committed Apr 13, 2023
1 parent 211f74a commit 00eea52
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 78 deletions.
16 changes: 12 additions & 4 deletions arch/x86/include/asm/tdx.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,24 @@ enum tdx_fuzz_loc {
TDX_FUZZ_MAX
};

/* forward declare to avoid dependence on virtio.h and ensuing cyclic dependency on device.h */
struct virtio_device;

#if defined(CONFIG_TDX_FUZZ) || defined(CONFIG_TDX_FUZZ_KAFL)
u64 tdx_fuzz(u64 var, uintptr_t addr, int size, enum tdx_fuzz_loc loc);
u64 tdx_fuzz_device_cache(struct device *dev, u64 orig_var, uintptr_t addr, int size, enum tdx_fuzz_loc type);
void tdx_fuzz_device_cache_invalidate(struct device *dev);
bool tdx_fuzz_err(enum tdx_fuzz_loc loc);
void tdx_fuzz_virtio_cache_init(struct virtio_device *vdev);
u64 tdx_fuzz_virtio_cache_get_64(struct virtio_device *vdev, u64 orig_var);
void tdx_fuzz_virtio_cache_refresh(struct device *dev);
#else
static inline u64 tdx_fuzz(u64 var, uintptr_t addr, int size, enum tdx_fuzz_loc loc) { return var; };
static inline u64 tdx_fuzz_device_cache(struct device *dev, u64 var, uintptr_t addr, int size, enum tdx_fuzz_loc loc) { return var; };
static inline void tdx_fuzz_device_cache_invalidate(struct device *dev) {}
static inline bool tdx_fuzz_err(enum tdx_fuzz_loc loc) { return false; }
static inline void tdx_fuzz_virtio_cache_init(struct virtio_device *vdev) { }
static inline u64 tdx_fuzz_virtio_cache_get_64(struct virtio_device *vdev, u64 orig_var)
{
return orig_var;
}
static inline void tdx_fuzz_virtio_cache_refresh(struct device *dev) { }
#endif

/*
Expand Down
89 changes: 19 additions & 70 deletions arch/x86/kernel/kafl-agent.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
#include <linux/memblock.h>
#include <linux/kprobes.h>
#include <linux/string.h>
#include <linux/hashtable.h>
#include <linux/mutex.h>
#include <linux/atomic.h>
#include <linux/device.h>
#include <linux/virtio.h>
#include <asm/tdx.h>
#include <asm/trace/tdx.h>
#include <asm-generic/sections.h>
Expand Down Expand Up @@ -78,9 +79,6 @@ u8 *ob_buf;
u32 ob_num;
u32 ob_pos;

DEFINE_MUTEX(virtio_cache_mutex);
DEFINE_HASHTABLE(virtio_cache, 10);

const char *kafl_event_name[KAFL_EVENT_MAX] = {
"KAFL_ENABLE",
"KAFL_START",
Expand Down Expand Up @@ -535,79 +533,30 @@ u64 tdx_fuzz(u64 orig_var, uintptr_t addr, int size, enum tdx_fuzz_loc type)
}
EXPORT_SYMBOL(tdx_fuzz);

struct virtio_cache_node {
#define MAX_VIRTIO_CACHE_SIZE 8
u64 fuzz_data;
struct device *dev;
struct hlist_node node;
};

/* Thoughts: it could be better to embed the cache into the device structure on an ifdef */
u64 tdx_fuzz_device_cache(struct device *dev, u64 orig_var, uintptr_t addr, int size, enum tdx_fuzz_loc type)
void tdx_fuzz_virtio_cache_init(struct virtio_device *vdev)
{
struct hlist_node *cur;
struct virtio_cache_node *got, *new_node;
// TODO is returning orig var indicate error?
u64 ret_data = orig_var;

if (size > MAX_VIRTIO_CACHE_SIZE)
return ret_data;

mutex_lock(&virtio_cache_mutex);
/*
* note: to keep this simple just use cache no matter the size (up to the max).
* For more granularity the key could be generated based on device pointer + size
* and use xxhash to generate the key.
*/
hash_for_each_possible_safe(virtio_cache, got, cur, node, (unsigned long)dev) {
/* Handle hash collisions and ignore if not expected */
if (dev != got->dev)
continue;

/* match use the cache */
ret_data = got->fuzz_data;
goto out;
}

/*
* NOT in the cache, add it until swiotlb_bounce called on device
*/
new_node = kmalloc(sizeof(*new_node), GFP_ATOMIC);
if (!new_node) {
goto out;
}

// TODO does this mean fail if orig var == returned var?
if (tdx_fuzz(0, &new_node->fuzz_data, sizeof(new_node->fuzz_data), type) == 0) {
kfree(new_node);
goto out;
}
u64 data = tdx_fuzz(0, (uintptr_t)&data, sizeof(data), TDX_FUZZ_VIRTIO);

hash_add(virtio_cache, new_node, (unsigned long)dev);

ret_data = new_node->fuzz_data;
out:
mutex_unlock(&virtio_cache_mutex);
return ret_data;
atomic64_set(&vdev->tdx.fuzz_data, data);
}
EXPORT_SYMBOL(tdx_fuzz_device_cache);
EXPORT_SYMBOL(tdx_fuzz_virtio_cache_init);

void tdx_fuzz_device_cache_invalidate(struct device *dev)
u64 tdx_fuzz_virtio_cache_get_64(struct virtio_device *vdev, u64 orig_var)
{
int i;
struct hlist_node *cur;
struct virtio_cache_node *got, *new_node;

mutex_lock(&virtio_cache_mutex);
/* orig_var needed for signature when fuzzing is disabled */
(void)orig_var;
return atomic64_read(&vdev->tdx.fuzz_data);
}
EXPORT_SYMBOL(tdx_fuzz_virtio_cache_get_64);

hash_for_each_safe(virtio_cache, i, cur, got, node) {
hash_del(&got->node);
kfree(got);
}
void tdx_fuzz_virtio_cache_refresh(struct device *dev)
{
if (!is_virtio_device(dev))
return;

mutex_unlock(&virtio_cache_mutex);
tdx_fuzz_virtio_cache_init(dev_to_virtio(dev));
}
EXPORT_SYMBOL(tdx_fuzz_device_cache);
EXPORT_SYMBOL(tdx_fuzz_virtio_cache_refresh);

bool tdx_fuzz_err(enum tdx_fuzz_loc type)
{
Expand Down
2 changes: 2 additions & 0 deletions drivers/virtio/virtio.c
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,8 @@ int register_virtio_device(struct virtio_device *dev)
INIT_LIST_HEAD(&dev->vqs);
spin_lock_init(&dev->vqs_list_lock);

tdx_fuzz_virtio_cache_init(dev);

/*
* device_add() causes the bus infrastructure to look for a matching
* driver.
Expand Down
5 changes: 5 additions & 0 deletions include/linux/virtio.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,11 @@ struct virtio_device {
const struct vringh_config_ops *vringh_config;
struct list_head vqs;
u64 features;
#ifdef CONFIG_TDX_FUZZ_KAFL_VIRTIO
struct {
atomic64_t fuzz_data;
} tdx;
#endif
void *priv;
};

Expand Down
6 changes: 3 additions & 3 deletions include/linux/virtio_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ static inline bool virtio_is_little_endian(struct virtio_device *vdev)
static inline u16 virtio16_to_cpu(struct virtio_device *vdev, __virtio16 val)
{
u16 ret = __virtio16_to_cpu(virtio_is_little_endian(vdev), val);
return tdx_fuzz_cached(&vdev->dev, ret, 0, sizeof(ret), TDX_FUZZ_VIRTIO);
return tdx_fuzz_virtio_cache_get_64(vdev, ret);
}

static inline __virtio16 cpu_to_virtio16(struct virtio_device *vdev, u16 val)
Expand All @@ -289,7 +289,7 @@ static inline __virtio16 cpu_to_virtio16(struct virtio_device *vdev, u16 val)
static inline u32 virtio32_to_cpu(struct virtio_device *vdev, __virtio32 val)
{
u32 ret = __virtio32_to_cpu(virtio_is_little_endian(vdev), val);
return tdx_fuzz_cached(&vdev->dev, ret, 0, sizeof(ret), TDX_FUZZ_VIRTIO);
return tdx_fuzz_virtio_cache_get_64(vdev, ret);
}

static inline __virtio32 cpu_to_virtio32(struct virtio_device *vdev, u32 val)
Expand All @@ -300,7 +300,7 @@ static inline __virtio32 cpu_to_virtio32(struct virtio_device *vdev, u32 val)
static inline u64 virtio64_to_cpu(struct virtio_device *vdev, __virtio64 val)
{
u64 ret = __virtio64_to_cpu(virtio_is_little_endian(vdev), val);
return tdx_fuzz_cached(&vdev->dev, ret, 0, sizeof(ret), TDX_FUZZ_VIRTIO);
return tdx_fuzz_virtio_cache_get_64(vdev, ret);
}

static inline __virtio64 cpu_to_virtio64(struct virtio_device *vdev, u64 val)
Expand Down
2 changes: 1 addition & 1 deletion kernel/dma/swiotlb.c
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ static void swiotlb_bounce(struct device *dev, phys_addr_t tlb_addr, size_t size
if (orig_addr == INVALID_PHYS_ADDR)
return;

tdx_fuzz_device_cache_invalidate(dev);
tdx_fuzz_virtio_cache_refresh(dev);

tlb_offset = tlb_addr & (IO_TLB_SIZE - 1);
orig_addr_offset = swiotlb_align_offset(dev, orig_addr);
Expand Down

0 comments on commit 00eea52

Please sign in to comment.