Skip to content

[WIP] Add support for custom network classifier cgroup in sandbox #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 3 commits into
base: databricks-bazel-6.5.0
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ private Optional<VirtualCGroup> getCgroup(Spawn spawn, SpawnExecutionContext con
VirtualCGroup cgroup = null;
long memoryLimit = sandboxOptions.memoryLimitMb * 1024L * 1024L;
float cpuLimit = sandboxOptions.cpuLimit;
boolean requiresNetwork = false;

if (sandboxOptions.executionInfoLimit) {
ExecutionRequirements.ParseableRequirement requirement = ExecutionRequirements.RESOURCES;
Expand All @@ -198,7 +199,9 @@ private Optional<VirtualCGroup> getCgroup(Spawn spawn, SpawnExecutionContext con
requirement = ExecutionRequirements.RESOURCES;
String name = null;
Float value = null;

if (tag.equals(ExecutionRequirements.REQUIRES_NETWORK)) {
requiresNetwork = true;
}
String extras = requirement.parseIfMatches(tag);
if (extras != null) {
int index = extras.indexOf(":");
Expand Down Expand Up @@ -259,6 +262,13 @@ private Optional<VirtualCGroup> getCgroup(Spawn spawn, SpawnExecutionContext con
cgroup.cpu().setCpus(cpuLimit);
}

if (!requiresNetwork && sandboxOptions.cgroupNetCls != 0) {
if (cgroup == null) {
cgroup = VirtualCGroup.getInstance(this.reporter).child(scope);
}
cgroup.netCls().setNetCls(sandboxOptions.cgroupNetCls);
}

cgroups.put(context.getId(), Optional.ofNullable(cgroup));

return cgroups.get(context.getId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,16 @@ public ImmutableSet<Path> getInaccessiblePaths(FileSystem fs) {
+ " Requires cgroups v1 or v2 and permissions for the users to the cgroups dir.")
public float cpuLimit;

@Option(
name = "experimental_sandbox_cgroup_net_cls",
defaultValue = "0",
documentationCategory = OptionDocumentationCategory.EXECUTION_STRATEGY,
effectTags = {OptionEffectTag.EXECUTION},
help =
"If set, any target not tagged with requires-network will run its actions "
+ "inside a sandbox with the given netcls for filtering by iptables firewalls.")
public int cgroupNetCls;

@Option(
name = "experimental_sandbox_execution_info_limit",
defaultValue = "false",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,8 @@ interface Cpu extends Controller {
interface CpuAcct extends Controller {
long getUsage() throws IOException;
}
interface NetCls extends Controller {
void setNetCls(int netCls) throws IOException;
int getNetCls() throws IOException;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
import com.google.devtools.build.lib.sandbox.cgroups.v1.LegacyCpu;
import com.google.devtools.build.lib.sandbox.cgroups.v1.LegacyCpuAcct;
import com.google.devtools.build.lib.sandbox.cgroups.v1.LegacyMemory;
import com.google.devtools.build.lib.sandbox.cgroups.v1.LegacyNetCls;
import com.google.devtools.build.lib.sandbox.cgroups.v2.UnifiedCpu;
import com.google.devtools.build.lib.sandbox.cgroups.v2.UnifiedMemory;
import com.google.devtools.build.lib.sandbox.cgroups.v2.UnifiedNetCls;

import javax.annotation.Nullable;
import java.io.BufferedReader;
Expand Down Expand Up @@ -53,6 +55,8 @@ public abstract class VirtualCGroup {
public abstract Controller.Memory memory();
@Nullable
public abstract Controller.CpuAcct cpuacct();
@Nullable
public abstract Controller.NetCls netCls();

public abstract ImmutableSet<Path> paths();

Expand Down Expand Up @@ -111,6 +115,7 @@ static VirtualCGroup create(File procMounts, File procCgroup, EventHandler repor
Controller.Memory memory = null;
Controller.Cpu cpu = null;
Controller.CpuAcct cpuacct = null;
Controller.NetCls netCls = null;
ImmutableSet.Builder<Path> paths = ImmutableSet.builder();

for (Mount m: mounts) {
Expand Down Expand Up @@ -157,6 +162,11 @@ static VirtualCGroup create(File procMounts, File procCgroup, EventHandler repor
logger.atInfo().log("Found cgroup v2 cpu controller at %s", cgroup);
cpu = new UnifiedCpu(cgroup);
break;
case "net_cls":
if (netCls != null) continue;
logger.atInfo().log("Found cgroup v2 net_cls controller at %s", cgroup);
netCls = new UnifiedNetCls(cgroup);
break;
}
}
} else {
Expand Down Expand Up @@ -188,14 +198,20 @@ static VirtualCGroup create(File procMounts, File procCgroup, EventHandler repor
logger.atInfo().log("Found cgroup v1 cpuacct controller at %s", cgroup);
cpuacct = new LegacyCpuAcct(cgroup);
break;
case "net_cls":
if (netCls != null) continue;
logger.atInfo().log("Found cgroup v1 net_cls controller at %s", cgroup);
netCls = new LegacyNetCls(cgroup);
break;
}
}
}
}

cpu = cpu != null ? cpu : Controller.getDefault(Controller.Cpu.class);
memory = memory != null ? memory : Controller.getDefault(Controller.Memory.class);
VirtualCGroup vcgroup = new AutoValue_VirtualCGroup(cpu, memory, cpuacct, paths.build());
netCls = netCls != null ? netCls : Controller.getDefault(Controller.NetCls.class);
VirtualCGroup vcgroup = new AutoValue_VirtualCGroup(cpu, memory, cpuacct, netCls, paths.build());
Runtime.getRuntime().addShutdownHook(new Thread(() -> vcgroup.delete()));
return vcgroup;
}
Expand All @@ -208,6 +224,7 @@ public void delete() {
public VirtualCGroup child(String name) throws IOException {
Controller.Cpu cpu = Controller.getDefault(Controller.Cpu.class);
Controller.Memory memory = Controller.getDefault(Controller.Memory.class);
Controller.NetCls netCls = null;
Controller.CpuAcct cpuacct = null;
ImmutableSet.Builder<Path> paths = ImmutableSet.builder();
if (memory() != null && memory().getPath() != null) {
Expand All @@ -232,7 +249,14 @@ public VirtualCGroup child(String name) throws IOException {
cpuacct = new LegacyCpuAcct(cgroup);
paths.add(cgroup);
}
VirtualCGroup child = new AutoValue_VirtualCGroup(cpu, memory, cpuacct, paths.build());
if (netCls() != null && netCls().getPath() != null) {
copyControllersToSubtree(netCls().getPath());
Path cgroup = netCls().getPath().resolve(name);
cgroup.toFile().mkdirs();
netCls = new LegacyNetCls(cgroup);
paths.add(cgroup);
}
VirtualCGroup child = new AutoValue_VirtualCGroup(cpu, memory, cpuacct, netCls, paths.build());
this.children.add(child);
return child;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package com.google.devtools.build.lib.sandbox.cgroups.v1;

import com.google.devtools.build.lib.sandbox.cgroups.Controller;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

public class LegacyNetCls implements Controller.NetCls {
private final Path path;

@Override
public Path getPath() throws IOException {
return path;
}

public LegacyNetCls(Path path) {
this.path = path;
}

@Override
public Path statFile() throws IOException {
return path.resolve("net_cls.stat");
}

@Override
public void setNetCls(int netCls) throws IOException {
Files.writeString(path.resolve("net_cls.classid"), Integer.toString(netCls));
}

@Override
public int getNetCls() throws IOException {
return Integer.parseInt(Files.readString(path.resolve("net_cls.classid")).trim());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package com.google.devtools.build.lib.sandbox.cgroups.v2;

import com.google.devtools.build.lib.sandbox.cgroups.Controller;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

public class UnifiedNetCls implements Controller.NetCls {
private final Path path;

public UnifiedNetCls(Path path) throws IOException {
this.path = path;
}

@Override
public Path getPath() {
return path;
}

@Override
public Path statFile() throws IOException {
return path.resolve("net_cls.stat");
}

@Override
public void setNetCls(int netCls) throws IOException {
Files.writeString(path.resolve("net_cls.classid"), Integer.toString(netCls));
}

@Override
public int getNetCls() throws IOException {
return Integer.parseInt(Files.readString(path.resolve("net_cls.classid")).trim());
}
}