Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RNN example to perform DGA detection with LSTMs #240

Merged
merged 8 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .ci/macos-steps.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ steps:
mkdir deps/
cd deps/

# Install Armadillo 9.800.1 (the oldest supported version).
curl -O http://files.mlpack.org/armadillo-9.800.1.tar.gz
tar xvzf armadillo-9.800.1.tar.gz
cd armadillo-9.800.1
# Install Armadillo 10.8.2 (the oldest supported version).
curl -O http://files.mlpack.org/armadillo-10.8.2.tar.gz
tar xvzf armadillo-10.8.2.tar.gz
cd armadillo-10.8.2
cmake .
make
sudo make install
cd ../
rm -rf armadillo-9.800.1/
rm -rf armadillo-10.8.2/

# Build and install the latest version of ensmallen.
curl -O https://www.ensmallen.org/files/ensmallen-latest.tar.gz
Expand Down
11 changes: 6 additions & 5 deletions cpp/kmeans/dominant-colors/dominant-colors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ void GetColorBarData(std::string& values,
colors = colorsString.str();
}

void dominantColors(std::string PathToImage)
void dominantColors(std::string PathToImage, std::string PathToColorBars)
{
// Load the example image.
arma::Mat<unsigned char> imageMatrix;
Expand Down Expand Up @@ -249,12 +249,13 @@ void dominantColors(std::string PathToImage)
GetColorBarData(values, colors, cluster, assignments, centroids);

// Show the dominant colors.
StackedBar(values, colors, "jurassic-park-colors.png");
StackedBar(values, colors, PathToColorBars);
}

int main()
{
dominantColors("../../../data/jurassic-park.png");
dominantColors("../../../data/the-grand-budapest-hotel.png");
dominantColors("../../../data/the-godfather.png");
dominantColors("../../../data/jurassic-park.png", "jurassic-park-colors.png");
dominantColors("../../../data/the-grand-budapest-hotel.png",
"the-grand-budapest-hotel-colors.png");
dominantColors("../../../data/the-godfather.png", "the-godfather-colors.png");
}
50 changes: 50 additions & 0 deletions cpp/lstm/dga_detection/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# This is a simple Makefile used to build the example source code.
# This example might requires some modifications in order to work correctly on
# your system.
# If you're not using the Armadillo wrapper, replace `armadillo` with linker commands
# for the BLAS and LAPACK libraries that you are using.

TARGET1 := lstm_dga_detection_train
SRC1 := lstm_dga_detection_train.cpp
TARGET2 := lstm_dga_detection_predict
SRC2 := lstm_dga_detection_predict.cpp
LIBS_NAME := armadillo

CXX := g++
CXXFLAGS += -std=c++17 -Wall -Wextra -O3 -DNDEBUG -fopenmp
# Use these CXXFLAGS instead if you want to compile with debugging symbols and
# without optimizations.
#CXXFLAGS += -std=c++17 -Wall -Wextra -g -O0 -fopenmp
LDFLAGS += -fopenmp
# Add header directories for any includes that aren't on the
# default compiler search path.
INCLFLAGS := -I .
INCLFLAGS += -I/home/ryan/src/mlpack/src/
INCLFLAGS += -I/home/ryan/src/ensmallen/include/
# If you have mlpack or ensmallen installed somewhere nonstandard, uncomment and
# update the lines below.
# INCLFLAGS += -I/path/to/mlpack/include/
# INCLFLAGS += -I/path/to/ensmallen/include/
CXXFLAGS += $(INCLFLAGS)

OBJS1 := $(SRC1:.cpp=.o)
OBJS2 := $(SRC2:.cpp=.o)
LIBS := $(addprefix -l,$(LIBS_NAME))
CLEAN_LIST := $(TARGET1) $(TARGET2) $(OBJS1) $(OBJS2)

# default rule
default: all

$(TARGET1): $(OBJS1)
$(CXX) $(CXXFLAGS) $(OBJS1) -o $(TARGET1) $(LDFLAGS) $(LIBS)

$(TARGET2): $(OBJS2)
$(CXX) $(CXXFLAGS) $(OBJS2) -o $(TARGET2) $(LDFLAGS) $(LIBS)

.PHONY: all
all: $(TARGET1) $(TARGET2)

.PHONY: clean
clean:
@echo CLEAN $(CLEAN_LIST)
@rm -f $(CLEAN_LIST)
151 changes: 151 additions & 0 deletions cpp/lstm/dga_detection/lstm_dga_detection_predict.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/**
* @file lstm_dga_detection_predict.cpp
* @author Ryan Curtin
*
* Given two trained DGA detection RNNs, make predictions. Domains should be
* input on stdin.
*
* Predictions are made by computing the likelihood of a domain coming from the
* benign model and from the malicious model. The predicted class is benign, if
* the likelihood of the domain coming from the benign model is higher (and
* malicious if vice versa).
*
* This is called the generalized likelihood ratio test (GLRT).
*
* To keep the model small and the code fast, we use `float` as a datatype
* instead of the default `double`.
*/
#include <mlpack.hpp>

// To keep compilation time and program size down, we only register
// serialization for layers used in our RNNs. Plus, given that we are using
// floats instead of doubles for our data, we need to register the layers for
// serialization either individually or all of them with
// CEREAL_REGISTER_MLPACK_LAYERS() (commented out below).
CEREAL_REGISTER_TYPE(mlpack::Layer<arma::fmat>);
CEREAL_REGISTER_TYPE(mlpack::MultiLayer<arma::fmat>);
CEREAL_REGISTER_TYPE(mlpack::RecurrentLayer<arma::fmat>);
CEREAL_REGISTER_TYPE(mlpack::LSTMType<arma::fmat>);
CEREAL_REGISTER_TYPE(mlpack::LinearType<arma::fmat>);
CEREAL_REGISTER_TYPE(mlpack::LogSoftMaxType<arma::fmat>);

// This will register all mlpack layers with the arma::fmat type.
// It is useful for playing around with the network architecture, but can make
// compilation time a lot longer. Comment out the individual
// CEREAL_REGISTER_TYPE() calls above if you use the line below.
//
// CEREAL_REGISTER_MLPACK_LAYERS(arma::fmat);

using namespace mlpack;
using namespace std;

// Utility function: map a character to the one-hot encoded dimension that
// represents it. Characters will be assigned to one of 38 dimensions. If an
// incorrect character is given, size_t(-1) is returned.
inline size_t CharToDim(const char inC)
{
char c = tolower(inC);
if (c >= 'a' && c <= 'z')
return size_t(c - 'a');
else if (c >= '0' && c <= '9')
return 26 + size_t(c - '0');
else if (c == '-')
return 36;
else if (c == '.')
return 37;
else
return size_t(-1);
}

// Utility function: turn a domain string into an arma::cube.
inline void PrepareString(const string& domain,
arma::fcube& data,
arma::fcube& response)
{
data.zeros(39, 1, domain.size());
response.set_size(1, 1, domain.size());

// One-hot encode each character.
for (size_t t = 0; t < domain.size(); ++t)
{
const size_t dim = CharToDim(domain[t]);
if (dim == size_t(-1))
{
cerr << "Domain '" << domain << "' has invalid character '" << domain[t]
<< "'!" << endl;
exit(1);
}

data(dim, 0, t) = 1.0;

if (t > 0)
response(0, 0, t - 1) = dim;
}

// Set end-of-input response.
response(0, 0, domain.size() - 1) = 38;
}

// Compute the likelihood that the string came from the model, given the
// predicted outputs of the model.
inline float ComputeLikelihood(const arma::fcube& predictions,
const arma::fcube& response)
{
float likelihood = 0.0;
for (size_t t = 0; t < response.n_slices; ++t)
likelihood += predictions((size_t) response(0, 0, t), 0, t);

return likelihood;
}

using namespace mlpack;
using namespace std;

int main(int argc, char** argv)
{
if (argc != 3)
{
cerr << "Usage: " << argv[0] << " benign_model.bin malicious_model.bin"
<< endl;
cerr << " - Train a model with the lstm_dga_detection_train program."
<< endl;
}

// First load the model.
RNN<NegativeLogLikelihoodType<arma::fmat>, RandomInitialization, arma::fmat>
benignModel, maliciousModel;
data::Load(argv[1], "lstm_model", benignModel, true /* fatal on failure */);
data::Load(argv[2], "lstm_model", maliciousModel, true);

// Now enter a loop where we read domains from stdin and then make
// predictions.
arma::fcube input, response, benignOutput, maliciousOutput;
while (true)
{
string line;
getline(cin, line);

if (cin.eof())
{
// The user has terminated the program.
return 0;
}

// Prepare the data for prediction and then make predictions with both
// models.
PrepareString(line, input, response);

// Now compute prediction.
benignModel.Predict(input, benignOutput);
maliciousModel.Predict(input, maliciousOutput);

const float benignLikelihood = ComputeLikelihood(benignOutput, response);
const float maliciousLikelihood = ComputeLikelihood(maliciousOutput,
response);

if (benignLikelihood > maliciousLikelihood)
cout << "benign" << endl;
else
cout << "malicious" << endl;
}
}
Loading
Loading