haripo.com

PyO3 で Python から Rust を利用する

Python の機械学習ライブラリ資産を活かしつつ、計算コストの高いロジックを Rust で書くために調べた内容をまとめます。

PyO3

Rust でビルドしたライブラリを Python から利用可能にするために PyO3 を利用します。まだ安定版ではないので、これ以降の記述は PyO3 v0.6.0 に基づいていることに注意してください。

競合として rust-cpython というのもあるのですが、rust-numpy が rust-cpython から PyO3 に移行したとのことなので、PyO3 を選択しました。比較については https://pyo3.rs/master/rust-cpython.html が参考になります(PyO3 側のドキュメントですが)。注意しなければならない PyO3 の大きなデメリットは Rust の nightly と Python 3.5 以降が要求される点です。逆に rust-cpython の要求は Python 2.7 or 3.3 ~ 3.7, Rust 1.25.0 なので、バージョンに制約がある場合は rust-cpython が有力かと思います。

インストールと動作確認

公式の string_sum サンプルを動作確認してみます。Rust の nighty が必要です。

# rustc -V
rustc 1.35.0-nightly (99da733f7 2019-04-12)

cargo init して、pyo3 への依存を追加します。crate-type は cdylib とします。

# Cargo.toml
[package]
name = "***"
version = "***"
authors = ["***"]
edition = "2018"

[lib]
name = "string_sum"
crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.6.0", features = ["extension-module"] }

lib.rs をこんなかんじに書きます。

use pyo3::prelude::*;
use pyo3::wrap_pyfunction;

#[pyfunction]
fn sum_as_string(a: usize, b: usize) -> PyResult<String> {
    Ok((a + b).to_string())
}

#[pymodule]
fn string_sum(py: Python, m: &PyModule) -> PyResult<()> {
    m.add_wrapped(wrap_pyfunction!(sum_as_string))?;
    Ok(())
}

cargo —release します。mac でビルドするときは .cargo/configを作っておく必要があります(https://pyo3.rs/master/#using-rust-from-python)。

[target.x86_64-apple-darwin]
rustflags = [
    "-C", "link-arg=-undefined",
    "-C", "link-arg=dynamic_lookup",
]

ビルド成功すると target/release/libstring_sum.dylib ができます。これを string_sum.so に書き換えて Python ファイルと同じディレクトリに配置します。Python からはモジュールとして利用可能です。

# main.py
import string_sum
print(string_sum.sum_as_string(2, 3))

これで Python から Rust のコードが呼べます。

# python3 main.py
5

よかったですね。

うまくいかない場合

error[E0554]: #![feature] may not be used on the stable release channel

Rust の stable build だとこうなる。nightly をインストールしてください。

error[E0554]: #![feature] may not be used on the stable release channel
    --> /Users/kisk/.cargo/registry/src/github.com-1ecc6299db9ec823/pyo3-0.6.0/src/lib.rs:1:1
    |
1 | #![feature(specialization)]
    | ^^^^^^^^^^^^^^^^^^^^^^^^^^^

error: aborting due to previous error

For more information about this error, try `rustc --explain E0554`.
The following warnings were emitted during compilation:

warning: pyo3 was unable to check rustc compatibility.
warning: Build may fail due to incompatible rustc version.

error: Could not compile `pyo3`.

note: Undefined symbols for architecture x86_64:

mac でちゃんと設定をする。

error: linking with `cc` failed: exit code: 1
    |
    = note: "cc" "-m64" "-L" (以下略)
    = note: Undefined symbols for architecture x86_64:
            "_PyDict_Size", referenced from:
                pyo3::types::dict::PyDict::len::h05c3fcd816c8b8f0 in libpyo3-d3c73d3aa8dd4f32.rlib(pyo3-d3c73d3aa8dd4f32.pyo3.1i0p2ovg-cgu.4.rcgu.o)
            "_PyObject_GetAttr", referenced from:
                pyo3::object::PyObject::getattr::_<span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>7</mn><mi>b</mi></mrow><annotation encoding="application/x-tex">u7b</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">7</span><span class="mord mathdefault">b</span></span></span></span><span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>7</mn><mi>b</mi></mrow><annotation encoding="application/x-tex">u7b</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">7</span><span class="mord mathdefault">b</span></span></span></span>closure<span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>7</mn><mi>d</mi></mrow><annotation encoding="application/x-tex">u7d</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">7</span><span class="mord mathdefault">d</span></span></span></span><span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>7</mn><mi>d</mi></mrow><annotation encoding="application/x-tex">u7d</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">7</span><span class="mord mathdefault">d</span></span></span></span>::h34bbd2a48bc53c7a in libpyo3-d3c73d3aa8dd4f32.rlib(pyo3-d3c73d3aa8dd4f32.pyo3.1i0p2ovg-cgu.5.rcgu.o)
                        (中略)
            "_PyObject_SetAttr", referenced from:
                _<span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>L</mi><mi>T</mi></mrow><annotation encoding="application/x-tex">LT</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord mathdefault">L</span><span class="mord mathdefault" style="margin-right:0.13889em;">T</span></span></span></span>T<span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>20</mn></mrow><annotation encoding="application/x-tex">u20</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">2</span><span class="mord">0</span></span></span></span>as<span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>20</mn></mrow><annotation encoding="application/x-tex">u20</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">2</span><span class="mord">0</span></span></span></span>pyo3..objectprotocol..ObjectProtocol<span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>G</mi><mi>T</mi></mrow><annotation encoding="application/x-tex">GT</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord mathdefault">G</span><span class="mord mathdefault" style="margin-right:0.13889em;">T</span></span></span></span>::setattr::_<span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>7</mn><mi>b</mi></mrow><annotation encoding="application/x-tex">u7b</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">7</span><span class="mord mathdefault">b</span></span></span></span><span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>7</mn><mi>b</mi></mrow><annotation encoding="application/x-tex">u7b</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">7</span><span class="mord mathdefault">b</span></span></span></span>closure<span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>7</mn><mi>d</mi></mrow><annotation encoding="application/x-tex">u7d</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">7</span><span class="mord mathdefault">d</span></span></span></span><span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>7</mn><mi>d</mi></mrow><annotation encoding="application/x-tex">u7d</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">7</span><span class="mord mathdefault">d</span></span></span></span>::_<span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>7</mn><mi>b</mi></mrow><annotation encoding="application/x-tex">u7b</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">7</span><span class="mord mathdefault">b</span></span></span></span><span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>7</mn><mi>b</mi></mrow><annotation encoding="application/x-tex">u7b</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">7</span><span class="mord mathdefault">b</span></span></span></span>closure<span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>7</mn><mi>d</mi></mrow><annotation encoding="application/x-tex">u7d</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">7</span><span class="mord mathdefault">d</span></span></span></span><span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>7</mn><mi>d</mi></mrow><annotation encoding="application/x-tex">u7d</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">7</span><span class="mord mathdefault">d</span></span></span></span>::hd327ab921c0ae35f in string_sum.2259yfgj5etko2a.rcgu.o
                _<span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>L</mi><mi>T</mi></mrow><annotation encoding="application/x-tex">LT</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord mathdefault">L</span><span class="mord mathdefault" style="margin-right:0.13889em;">T</span></span></span></span>T<span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>20</mn></mrow><annotation encoding="application/x-tex">u20</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">2</span><span class="mord">0</span></span></span></span>as<span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>20</mn></mrow><annotation encoding="application/x-tex">u20</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.64444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">2</span><span class="mord">0</span></span></span></span>pyo3..objectprotocol..ObjectProtocol<span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>G</mi><mi>T</mi></mrow><annotation encoding="application/x-tex">GT</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.68333em;vertical-align:0em;"></span><span class="mord mathdefault">G</span><span class="mord mathdefault" style="margin-right:0.13889em;">T</span></span></span></span>::setattr::_<span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>7</mn><mi>b</mi></mrow><annotation encoding="application/x-tex">u7b</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">7</span><span class="mord mathdefault">b</span></span></span></span><span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>7</mn><mi>b</mi></mrow><annotation encoding="application/x-tex">u7b</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">7</span><span class="mord mathdefault">b</span></span></span></span>closure<span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>7</mn><mi>d</mi></mrow><annotation encoding="application/x-tex">u7d</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">7</span><span class="mord mathdefault">d</span></span></span></span><span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>7</mn><mi>d</mi></mrow><annotation encoding="application/x-tex">u7d</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">7</span><span class="mord mathdefault">d</span></span></span></span>::_<span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>7</mn><mi>b</mi></mrow><annotation encoding="application/x-tex">u7b</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">7</span><span class="mord mathdefault">b</span></span></span></span><span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>7</mn><mi>b</mi></mrow><annotation encoding="application/x-tex">u7b</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">7</span><span class="mord mathdefault">b</span></span></span></span>closure<span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>7</mn><mi>d</mi></mrow><annotation encoding="application/x-tex">u7d</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">7</span><span class="mord mathdefault">d</span></span></span></span><span class="katex"><span class="katex-mathml"><math><semantics><mrow><mi>u</mi><mn>7</mn><mi>d</mi></mrow><annotation encoding="application/x-tex">u7d</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height:0.69444em;vertical-align:0em;"></span><span class="mord mathdefault">u</span><span class="mord">7</span><span class="mord mathdefault">d</span></span></span></span>::h2fa32525c10b1ed0 in libpyo3-d3c73d3aa8dd4f32.rlib(pyo3-d3c73d3aa8dd4f32.pyo3.1i0p2ovg-cgu.6.rcgu.o)
            ld: symbol(s) not found for architecture x86_64
            clang: error: linker command failed with exit code 1 (use -v to see invocation)

error: aborting due to previous error

ImportError: dlopen

pyo3 最新版で Python 2 は非対応です。Python3 に移行しましょう。

ImportError: dlopen(/Users/kisk/Dropbox/Projects/rust-python-bridge/python/string_sum.so, 2): Symbol not found: _PyModule_Create2
    Referenced from: /Users/kisk/Dropbox/Projects/rust-python-bridge/python/string_sum.so
    Expected in: flat namespace

いろいろな型を扱ってみる

数値

sum_string のサンプルは usize でした。f64 と u32 にしてみる。

fn sum_as_string(a: f64, b: u32) -> PyResult<String> {
    Ok((a + b as f64).to_string())
}

普通に使えます。

import string_sum
print(string_sum.sum_as_string(1.4, 2))

たとえば u32 の引数に float を渡すとどうなるでしょう。

import string_sum
print(string_sum.sum_as_string(100.0, 3.2))

ちゃんとしたわかりやすいエラーが出ます。えらい。

Traceback (most recent call last):
    File "python/main.py", line 3, in <module>
    print(string_sum.sum_as_string(100.0, 3.2))
TypeError: 'float' object cannot be interpreted as an integer

別のエラーを発生させてみましょう。Python 3 の数値型は long 長を超えられるので、巨大整数を u32 に渡してみます。

import string_sum
print(string_sum.sum_as_string(1, 100 ** 100))

OverflowError が発生します。わかりやすいですね。

Traceback (most recent call last):
    File "python/main.py", line 3, in <module>
    print(string_sum.sum_as_string(1, 100 ** 100))
OverflowError: Python int too large to convert to C long

こういうときに RuntimeError で何も言わずに死んだりしないので安心です。

文字列

普通に str で受け取ることができます。

#[pyfunction]
fn get_length(a: &str) -> PyResult<usize> {
    Ok(a.len())
}

リスト

Vec で受け取ることができます。

#[pyfunction]
fn multiply_array(a: Vec<usize>, b: usize) -> PyResult<Vec<usize>> {
    Ok(a.iter().map(|v| v * b).collect())
}

タプル

タプルとして受け取ることができます。

#[pyfunction]
fn multiply_tuple(a: (usize, usize), b: usize) -> PyResult<(usize, usize)> {
    Ok((a.0 * b, a.1 * b))
}

辞書

dict 型は PyDict として受け取ります。

#[pyfunction]
fn multiply_dict(py: Python, a: &PyDict, b: usize) -> PyResult<PyObject> {
    let x_value: &PyLong = a.get_item("x").ok_or(KeyError)?.try_into()?; // PyO3 の Error に書き換える必要がある
    let y_value: &PyLong = a.get_item("y").ok_or(KeyError)?.try_into()?;

    let new_x_value = x_value.extract::<usize>()? * b;
    let new_y_value = y_value.extract::<usize>()? * b;

    let result = PyDict::new(py);
    result.set_item("a", new_x_value)?;
    result.set_item("b", new_y_value)?;
    Ok(result.into_object(py))
}

日付時刻

PyDate として受け取ります。

#[pyfunction]
fn get_next_year_date(py: Python, date: &PyDate) -> PyResult<PyObject> {
    Ok(PyDate::new(py, date.get_year() + 1, date.get_month(), date.get_day())?.into_object(py))
}

こういう書き方もできるようです。

#[pyfunction]
fn get_next_year_date(py: Python<'_>, date: &PyDate) -> PyResult<Py<PyDate>> {
    PyDate::new(py, date.get_year() + 1, date.get_month(), date.get_day())
}

どういう違いがあるかは正直よくわかっていません。ドキュメントが拡充されることを祈ります。

メソッド

Python からメソッド(lambda)を受け取り、Rust 側からコールする例です。

#[pyfunction]
fn call_lambda(py: Python, lambda: PyObject) -> PyResult<usize> {
    let result = lambda.call0(py);
    result?.extract::<usize>(py)
}

#[pyfunction]
fn call_lambda_with_arg(py: Python, lambda: PyObject) -> PyResult<usize> {
    let result = lambda.call1(py, PyTuple::new(py, [123].into_iter()));
    result?.extract::<usize>(py)
}

引数がない場合は call0 , ある場合は call1 、キーワード引数もある場合は call2 を利用します。

クラス

#[pyclass]
struct SomeStructInRust {
    pub number: i32,
    pub text: String
}


#[pymethods]
impl SomeStructInRust {
    #[new]
    fn new(obj: &PyRawObject, num: i32) {
        obj.init({
            SomeStructInRust {
                number: num,
                text: "default value".to_string()
            }
        });
    }

    #[getter]
    fn number(&self) -> PyResult<i32> {
        Ok(self.number)
    }

    #[getter]
    fn text(&self) -> PyResult<&String> {
        Ok(&self.text)
    }

    fn concat(&self) -> PyResult<String> {
        Ok(format!("{}-{}", &self.text, self.number))
    }
}

クラスは add_class を用いてモジュールに登録します。

#[pymodule]
fn string_sum(_py: Python, m: &PyModule) -> PyResult<()> {
    // 省略
    m.add_class::<SomeStructInRust>()?;
    Ok(())
}

Python 側からクラスとして扱うことができます。

print(string_sum.SomeStructInRust(1))

Rust-Python 間でクラスを渡したり受けっとたりできます。

#[pyfunction]
fn get_rust_class(py: Python) -> PyResult<Py<SomeStructInRust>> {
    let result = SomeStructInRust {
        number: 1,
        text: "abc".to_string()
    };
    Py::new(py, result)
}


#[pyfunction]
fn pass_python_class(py: Python, a: PyObject) -> PyResult<usize> {
    let tm = a.getattr(py, "number")?;
    let x_value: &PyLong = tm.extract(py)?;
    Ok(x_value.extract::<usize>()? * 10)
}

まとめ

PyO3 で Python から Rust のモジュールを利用する方法について説明しました。ドキュメントが不足している部分もあり特に Py PyObject まわりの役割については不明瞭なままですが、syntax はわかりやすいですし、型違いの場合のエラーハンドリングなどしっかりした挙動になっているなという印象です。