实现JVM中的JIT

Kevin Lynx 2017-03-09 00:00

在JVM中,JIT (Just-in-Time) 即时编译指的是在Java程序运行过程中JVM优化部分指令为本地指令,从而大幅提升性能。在上一篇文章写一个玩具Java虚拟机中实现了一个基本可以运行Java字节码的JVM。本篇文章描述我是如何在这个玩具JVM中实现JIT的。

推荐文章“How to JIT - an introduction”,介绍了JIT的基本实现原理。作者把JIT分为两个阶段:

  • 运行期生成机器代码(本地指令)
  • 执行机器代码

生成机器代码很好理解,就是一个JVM指令到机器指令的翻译;而执行机器代码,原理上是利用了OS提供了API可以分配可以执行的内存,然后往这块内存中写入机器码,从而实现运行期可以执行动态生成的机器码功能。

我们可以利用这个原理来实现JIT,但是未免太底层了点,需要做很多工作来完成这件事情。我们可以利用libjit来简化实现。这个作者博客里还有些libjit的教程,其中part 1值得阅读。 简单来说,libjit对机器指令做了抽象,利用它的API来描述一个函数包含了哪些指令,实现了什么功能。然后具体的指令生成以及指令执行则交给libjit完成。

例如以下使用libjit的代码:

1
2
3
4
5
6
7
8
// t = u
jit_insn_store(F, t, u); // 类似 mov 指令
// u = v
jit_insn_store(F, u, v);

// v = t % v
jit_value_t rem = jit_insn_rem(F, t, v); // 求余指令
jit_insn_store(F, v, rem);

所以,我们需要做的,就是将JVM的字节码,翻译成一堆libjit的API调用。但是我希望能够稍微做点抽象,我们写个翻译器,能够将JVM这种基于栈的指令,翻译成基于寄存器的指令,才方便后面无论是使用libjit还是直接翻译成机器码。

指令翻译

要将基于栈的指令翻译成基于寄存器的指令(类似),仔细想想主要解决两个问题:

  • 去除操作数栈
  • 跳转指令所需要的标签

去除操作数栈,我使用了一个简单办法,因为JVM中执行字节码时,我们是可以知道每条指令执行时栈的具体情况的,也就是每条指令执行时,它依赖的操作数在栈的哪个位置是清楚的。例如,假设某个函数开头有以下指令:

1
2
3
4
5
opcode [04] - 0000: iconst_1   # [1]
opcode [3C] - 0001: istore_1   # []
opcode [1B] - 0002: iload_1    # [1]
opcode [1A] - 0003: iload_0    # [1, N]
opcode [68] - 0004: imul       # [1 * N]

当执行imul指令时,就可以知道该指令使用栈s[0]、s[1]的值,做完计算后写回s[0]。所以,类似JVM中局部变量用数字来编号,我也为栈元素编号,这些编号的元素全部被视为局部变量,所以这些指令全部可以转换为基于局部变量的指令。为了和JVM中本身的局部变量统一,我们将栈元素编号从局部变量后面开始。假设以上函数有2个局部变量,那么栈元素从编号2开始,局部变量编号从0开始。以上指令可以翻译为:

1
2
3
4
5
mov 1, $2   # 常量1写入变量2
lod $2, $1  # 变量2写入变量1
lod $1, $2  # 变量1写回变量2
lod $0, $3  # 变量0写入变量3
mul $3, $2  # 变量3与变量2相乘,写回变量2

这里,我们定义了自己的中间指令集(IR),这个中间指令集存在的意义在于,在将来翻译为某个平台的机器码时,它比JVM的指令集更容易理解。中间指令集是一种抽象,方便基于它们使用libjit或其他手段翻译成机器码。

不过,我们看到上面的指令非常冗余。要优化掉这种冗余相对比较复杂,所以本文暂时不讨论这个问题。

这个中间指令基于局部变量的方式,是利于JIT下游做各种具体实现的,例如是否直接转换为通用寄存器,即一定范围的局部变量数是可以直接使用寄存器实现,超出该范围的局部变量则放在栈上,用栈模拟;或者全部用栈模拟。注意在机器指令中栈元素是可以直接偏移访问的,不同于“基于栈的虚拟机”中的栈。

以上指令,我们可以简单地为每条指令设定如何翻译为libjit的调用,例如mov指令:

1
2
3
4
static void build_mov(BuildContext* context, const Instruction* inst) {
  jit_value_t c = jit_value_create_nint_constant(context->F, jit_type_int, inst->op1); 
  jit_insn_store(context->F, context->vars[inst->op2], c);
}

例如mul指令:

1
2
3
4
5
static void build_mul(BuildContext* context, const Instruction* inst) {
  // context->vars就是前面说的局部变量表,包含了JVM中的局部变量及操作数栈
  jit_value_t tmp = jit_insn_mul(context->F, context->vars[inst->op1], context->vars[inst->op2]);
  jit_insn_store(context->F, context->vars[inst->op1], tmp);
}

接下来说另一个问题:跳转指令的标签。在机器指令中,跳转指令跳转的目标位置是一个绝对地址,或者像JVM中一样,是一个相对地址。但是在我们的中间指令集中,是没有地址的概念的,在翻译为机器指令时,也无法获取地址。所以,我们一般是增加了一个特殊指令label,用于打上一个标签,设置一个标签编号,相当于是一个地址。在后面的跳转指令中,则跳转的是这个标签编号。

所以,我们需要在翻译JVM指令到我们的中间指令时,识别出哪些地方需要打标签;并且在翻译跳转类指令时,翻译为跳转到某个编号的标签。

例如以下指令:

1
2
3
4
5
6
7
opcode [04] - 0000: iconst_1
opcode [3C] - 0001: istore_1
opcode [1B] - 0002: iload_1     # 会被调整,需要在此打标签
opcode [1A] - 0003: iload_0
...
opcode [1A] - 0010: iload_0
opcode [9D] - 0011: ifgt -9     # pc-9,也就是跳转到0002位置

为了打上标签,我们的翻译需要遍历两遍指令,第一遍用来找出所有标签,第二遍才做真正的翻译。

1
2
3
4
5
6
7
8
9
10
11
12
  // 该函数遍历所有指令,找出所有需要打标签的指令位置
  private List<Integer> createLabels(List<InstParser.Instruction> jbytecode) {
    List<Integer> labels = new LinkedList<>();
    for (InstParser.Instruction i : jbytecode) {
      LabelParser labelParser = labelParsers.get(i.opcode);
      if (labelParser != null) { // 不为空时表示是跳转指令
        int pc = labelParser.parse(i); // 不同的跳转指令地址解析不同,解析得到跳转的目标地址
        labels.add(pc); // 保存起来返回
      }
    }
    return labels;
  }

然后在翻译指令的过程中,发现当前翻译的指令地址是跳转的目标位置时,则生成标签指令:

1
2
3
4
5
6
7
8
9
   List<Integer> labels = createLabels(jbytecode);
   ...
   Iterator<InstParser.Instruction> it = jbytecode.iterator();
    while (it.hasNext()) {
      InstParser.Instruction inst = it.next();
      int label = labels.indexOf(inst.pc);
      if (label >= 0) {
        state.addIR(new Inst(op_label, label)); // 生成标签指令,label就是标签编号
      }

在处理跳转指令时,则填入标签编号:

1
2
3
4
5
6
7
translators.put(Opcode.op_ifgt, (state, inst, iterator) -> {
  short offset = (short)((inst.op1 << 8) + inst.op2);
  int pc = inst.pc + offset;
  int label = state.findLabel(pc); // 找到标签编号
  int var = state.popStack();
  state.addIR(new Inst(op_jmp_gt, var, label));
});

我们的中间指令集中,跳转指令和标签指令就为:

1
2
label #N            // 打上标签N
jmp_gt $var, #N     // 如果$var>0,跳转到标签#N

看下使用libjit如何翻译以上两条指令:

1
2
3
4
5
6
7
8
9
10
11
12
static void build_label(BuildContext* context, const Instruction* inst) {
  // 打上标签,inst->op1为标签编号N,对应写到context->labels[N]中
  jit_insn_label(context->F, &context->labels[inst->op1]);
}

static void build_jmp_gt(BuildContext* context, const Instruction* inst) {
  jit_value_t const0 = jit_value_create_nint_constant(context->F, jit_type_int, 0);
  // 是否>0
  jit_value_t cmp_v_0 = jit_insn_gt(context->F, context->vars[inst->op1], const0);
  // 大于0则跳转到标签inst->op2
  jit_insn_branch_if(context->F, cmp_v_0, &context->labels[inst->op2]);
}

代码贴得有点多,大概懂原理就行了。

在JIT中还有个很重要的过程,就是判定哪些代码需要被JIT。这里只是简单地尝试对每一个函数进行JIT,发现所有指令都能够被JIT时就JIT。

指令执行

在上一篇文章中,执行每个JVM函数时,都会有一个Frame与之关联。所以,在这里只要函数被JIT了,对应的帧就会包含被编译的代码,也就是libjit中的jit_function_t。在该Frame被执行时,就调用libjit执行该函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
  private void runNative() {
    int arg_cnt = getArgsCount();
    int[] args = new int[arg_cnt];
    for (int i = 0; i < arg_cnt; ++i) {
      if (mLocals[i].type != Slot.Type.NUM) throw new RuntimeException("only supported number arg in jit");
      args[i] = mLocals[i].i;
    }
    int ret = mJIT.invoke(args); // mJIT后面会看到,主要就是将参数以数组形式传递到libjit中,并做JIT函数调用
    mThread.popFrame();
    if (hasReturnType() && mThread.topFrame() != null) {
      mThread.topFrame().pushInt(ret); // 目前只支持int返回类型
    }
  }

实现

以上就是整个JIT的过程,主要工作集中于JVM指令到中间指令,中间指令到libjit API调用。整个实现包含以下模块:

1
2
3
4
5
6
7
8
9
10
+-----------+       +----------+
| ASM       |       | libjit   |
|           | <-----+ API call |
+-----------+       +----+-----+
                         ^
                         |
+-----------+       +----+-----+
|  JVM      |       |  IR code |
|  bytecode +-----> |          |
+-----------+       +----------+

JVM byte code及IR code的处理是在Java中完成的;处理完后将IR code输出为byte[],通过JNI调用包装好的C API。这个C API则是基于libjit,将IR code翻译为libjit的API调用。指令翻译完后调用libjit的API得到最终的ASM机器指令。

同样,要执行指令时,也是通过JNI调用这个C API。JNI交互全部包装在以下类中:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
public class ToyJIT {
  private long jitPtr = 0;

  public void initialize(byte[] bytes, int maxLocals, int maxLabels, int argCnt, int retType) {
    jitPtr = compile(bytes, maxLocals, maxLabels, argCnt, retType);
  }

  public int invoke(int... args) {
    return invoke(jitPtr, args);
  }

  static {
    System.loadLibrary("toyjit");
  }
  private static native long compile(byte[] bytes, int maxLocals, int maxLabels, int argCnt, int retType);
  private static native int invoke(long jitPtr, int[] args);

即,libtoyjit.so主要提供翻译接口compile及执行接口invoke。

性能对比

简单测试了下一个阶乘计算函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
  public static int fac2(int n) {
    int r = 1;
    do {
      r = r * n;
      n = n - 1;
    } while (n > 0);
    return r;
  }

 ...
    int i = 0;
    for (; i < 10000; ++i) {
        fac2(100);
    }
 ...

fac2函数会被JIT,测试发现不开启JIT时需要16秒,开启后1秒,差距还是很明显的。

最后奉上代码,toy_jit,就是前面说的C API部分,翻译IR到libjit API call,包装接口用于JNI调用。redhat 7.2下编译,需要先编译出libjit,我是直接clone的libjit master编译的。Java部分还是在toy_jvm中。

原文地址:http://codemacro.com/2017/03/09/toy-jit/
written byKevin Lynx posted athttp://codemacro.com

[返回] [原文链接]