Compare commits
51 Commits
master ... development
| Author | SHA1 | Date |
|---|---|---|
|  | 3734d8a69f |  |
|  | c07fbd48fc |  |
|  | bbf6c2d8a3 |  |
|  | ff65af52d2 |  |
|  | bad40a5c0b |  |
|  | 353a740508 |  |
|  | 2e7039750d |  |
|  | ac96fb5603 |  |
|  | adc95d1c5f |  |
|  | dfd62c6f09 |  |
|  | 6bef8f639e |  |
|  | ee78d43549 |  |
|  | adf37ddf34 |  |
|  | 18a81bd94c |  |
|  | e95ad219a5 |  |
|  | dd7969ed1f |  |
|  | 50349d2c18 |  |
|  | 8b4085edc5 |  |
|  | bcef8723a8 |  |
|  | 01247e334d |  |
|  | 25439e932a |  |
|  | 1b08dfef51 |  |
|  | 8e50e93a55 |  |
|  | 21ae3b2efc |  |
|  | d0a0322083 |  |
|  | c3ad71cf13 |  |
|  | 0dc53013ca |  |
|  | fc42ebbd77 |  |
|  | e2ea6ddbce |  |
|  | 7747ac472d |  |
|  | 23991efaff |  |
|  | cfc745799b |  |
|  | f7b6a73a7d |  |
|  | 36e36a4736 |  |
|  | dfd88d152c |  |
|  | 29407b7a3a |  |
|  | 9711c7ce31 |  |
|  | f77abcb2aa |  |
|  | eb92e5e872 |  |
|  | e017a34865 |  |
|  | cc22df130c |  |
|  | cab71cec9d |  |
|  | d660d13956 |  |
|  | 8d1c21325e |  |
|  | dd99b14a25 |  |
|  | 01ecd92103 |  |
|  | a9af415767 |  |
|  | b619993344 |  |
|  | d66ca4a332 |  |
|  | 6fb600b38f |  |
|  | 804114b294 |  |
@@ -1,7 +1,6 @@
*.iml
.gradle
/local.properties
.idea/
.DS_Store
/build
/captures
@@ -0,0 +1,6 @@
[submodule "app/src/main/cpp/src/libs/sentencepiece"]
	path = app/src/main/cpp/src/libs/sentencepiece
	url = https://github.com/niedev/sentencepiece
[submodule "app/src/main/cpp/src/libs/bergamot"]
	path = app/src/main/cpp/src/libs/bergamot
	url = https://github.com/niedev/bergamot-translator
@@ -0,0 +1,13 @@
# Default ignored files
/shelf/
/usage.statistics.xml
/workspace.xml
#extra ignored files
/caches
/caches/*
/gradle.xml
/dataSources.ids
/datasources.xml
/modules.xml
/dictionaries
/libraries
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="AndroidProjectSystem">
    <option name="providerId" value="com.android.tools.idea.GradleProjectSystem" />
  </component>
</project>
File diff suppressed because it is too large
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="CompilerConfiguration">
    <bytecodeTargetLevel target="23" />
  </component>
</project>
@@ -0,0 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="deploymentTargetSelector">
    <selectionStates>
      <SelectionState runConfigName="app">
        <option name="selectionMode" value="DROPDOWN" />
      </SelectionState>
    </selectionStates>
  </component>
</project>
@@ -0,0 +1,13 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="DeviceTable">
    <option name="columnSorters">
      <list>
        <ColumnSorterState>
          <option name="column" value="Name" />
          <option name="order" value="ASCENDING" />
        </ColumnSorterState>
      </list>
    </option>
  </component>
</project>
@@ -0,0 +1,19 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="GradleMigrationSettings" migrationVersion="1" />
  <component name="GradleSettings">
    <option name="linkedExternalProjectsSettings">
      <GradleProjectSettings>
        <option name="testRunner" value="CHOOSE_PER_TEST" />
        <option name="externalProjectPath" value="$PROJECT_DIR$" />
        <option name="gradleJvm" value="corretto-23" />
        <option name="modules">
          <set>
            <option value="$PROJECT_DIR$" />
            <option value="$PROJECT_DIR$/app" />
          </set>
        </option>
      </GradleProjectSettings>
    </option>
  </component>
</project>
@@ -0,0 +1,9 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="Kotlin2JsCompilerArguments">
    <option name="moduleKind" value="plain" />
  </component>
  <component name="Kotlin2JvmCompilerArguments">
    <option name="jvmTarget" value="1.8" />
  </component>
</project>
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="MarkdownSettings">
    <option name="previewPanelProviderInfo">
      <ProviderInfo name="Compose (experimental)" className="com.intellij.markdown.compose.preview.ComposePanelProvider" />
    </option>
  </component>
</project>
@@ -0,0 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectMigrations">
    <option name="MigrateToGradleLocalJavaHome">
      <set>
        <option value="$PROJECT_DIR$" />
      </set>
    </option>
  </component>
</project>
@@ -0,0 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ExternalStorageConfigurationManager" enabled="true" />
  <component name="ProjectRootManager" version="2" languageLevel="JDK_23" default="true" project-jdk-name="corretto-23" project-jdk-type="JavaSDK">
    <output url="file://$PROJECT_DIR$/build/classes" />
  </component>
  <component name="ProjectType">
    <option name="id" value="Android" />
  </component>
</project>
@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="RunConfigurationProducerService">
    <option name="ignoredProducers">
      <set>
        <option value="com.intellij.execution.junit.AbstractAllInDirectoryConfigurationProducer" />
        <option value="com.intellij.execution.junit.AllInPackageConfigurationProducer" />
        <option value="com.intellij.execution.junit.PatternConfigurationProducer" />
        <option value="com.intellij.execution.junit.TestInClassConfigurationProducer" />
        <option value="com.intellij.execution.junit.UniqueIdConfigurationProducer" />
        <option value="com.intellij.execution.junit.testDiscovery.JUnitTestDiscoveryConfigurationProducer" />
        <option value="org.jetbrains.kotlin.idea.junit.KotlinJUnitRunConfigurationProducer" />
        <option value="org.jetbrains.kotlin.idea.junit.KotlinPatternConfigurationProducer" />
      </set>
    </option>
  </component>
</project>
@@ -0,0 +1,13 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="" vcs="Git" />
    <mapping directory="$PROJECT_DIR$/app/src/main/cpp/src/libs/bergamot" vcs="Git" />
    <mapping directory="$PROJECT_DIR$/app/src/main/cpp/src/libs/bergamot/3rd_party/marian-dev" vcs="Git" />
    <mapping directory="$PROJECT_DIR$/app/src/main/cpp/src/libs/bergamot/3rd_party/ssplit-cpp" vcs="Git" />
    <mapping directory="$PROJECT_DIR$/app/src/main/cpp/src/libs/sentencepiece" vcs="Git" />
  </component>
  <component name="VcsProjectSettings">
    <option name="detectVcsMappingsAutomatically" value="false" />
  </component>
</project>
@@ -0,0 +1,74 @@
<component name="ProjectRunConfigurationManager">
  <configuration default="false" name="app" type="AndroidRunConfigurationType" factoryName="Android App" activateToolWindowBeforeRun="false">
    <module name="RTranslator.app" />
    <option name="ANDROID_RUN_CONFIGURATION_SCHEMA_VERSION" value="1" />
    <option name="DEPLOY" value="true" />
    <option name="DEPLOY_APK_FROM_BUNDLE" value="false" />
    <option name="DEPLOY_AS_INSTANT" value="false" />
    <option name="ARTIFACT_NAME" value="" />
    <option name="PM_INSTALL_OPTIONS" value="" />
    <option name="ALL_USERS" value="false" />
    <option name="ALWAYS_INSTALL_WITH_PM" value="false" />
    <option name="ALLOW_ASSUME_VERIFIED" value="false" />
    <option name="CLEAR_APP_STORAGE" value="false" />
    <option name="DYNAMIC_FEATURES_DISABLED_LIST" value="" />
    <option name="ACTIVITY_EXTRA_FLAGS" value="" />
    <option name="MODE" value="default_activity" />
    <option name="RESTORE_ENABLED" value="false" />
    <option name="RESTORE_FILE" value="" />
    <option name="RESTORE_FRESH_INSTALL_ONLY" value="false" />
    <option name="CLEAR_LOGCAT" value="false" />
    <option name="SHOW_LOGCAT_AUTOMATICALLY" value="false" />
    <option name="TARGET_SELECTION_MODE" value="DEVICE_AND_SNAPSHOT_COMBO_BOX" />
    <option name="SELECTED_CLOUD_MATRIX_CONFIGURATION_ID" value="-1" />
    <option name="SELECTED_CLOUD_MATRIX_PROJECT_ID" value="" />
    <option name="DEBUGGER_TYPE" value="Auto" />
    <Auto>
      <option name="USE_JAVA_AWARE_DEBUGGER" value="false" />
      <option name="SHOW_STATIC_VARS" value="true" />
      <option name="WORKING_DIR" value="" />
      <option name="TARGET_LOGGING_CHANNELS" value="lldb process:gdb-remote packets" />
      <option name="SHOW_OPTIMIZED_WARNING" value="true" />
      <option name="ATTACH_ON_WAIT_FOR_DEBUGGER" value="false" />
      <option name="DEBUG_SANDBOX_SDK" value="false" />
    </Auto>
    <Hybrid>
      <option name="USE_JAVA_AWARE_DEBUGGER" value="false" />
      <option name="SHOW_STATIC_VARS" value="true" />
      <option name="WORKING_DIR" value="" />
      <option name="TARGET_LOGGING_CHANNELS" value="lldb process:gdb-remote packets" />
      <option name="SHOW_OPTIMIZED_WARNING" value="true" />
      <option name="ATTACH_ON_WAIT_FOR_DEBUGGER" value="false" />
      <option name="DEBUG_SANDBOX_SDK" value="false" />
    </Hybrid>
    <Java>
      <option name="ATTACH_ON_WAIT_FOR_DEBUGGER" value="false" />
      <option name="DEBUG_SANDBOX_SDK" value="false" />
    </Java>
    <Native>
      <option name="USE_JAVA_AWARE_DEBUGGER" value="false" />
      <option name="SHOW_STATIC_VARS" value="true" />
      <option name="WORKING_DIR" value="" />
      <option name="TARGET_LOGGING_CHANNELS" value="lldb process:gdb-remote packets" />
      <option name="SHOW_OPTIMIZED_WARNING" value="true" />
      <option name="ATTACH_ON_WAIT_FOR_DEBUGGER" value="false" />
      <option name="DEBUG_SANDBOX_SDK" value="false" />
    </Native>
    <Profilers>
      <option name="ADVANCED_PROFILING_ENABLED" value="false" />
      <option name="STARTUP_PROFILING_ENABLED" value="false" />
      <option name="STARTUP_CPU_PROFILING_ENABLED" value="false" />
      <option name="STARTUP_CPU_PROFILING_CONFIGURATION_NAME" value="Java/Kotlin Method Sample (legacy)" />
      <option name="STARTUP_NATIVE_MEMORY_PROFILING_ENABLED" value="false" />
      <option name="NATIVE_MEMORY_SAMPLE_RATE_BYTES" value="2048" />
    </Profilers>
    <option name="DEEP_LINK" value="" />
    <option name="ACTIVITY" value="" />
    <option name="ACTIVITY_CLASS" value="" />
    <option name="SEARCH_ACTIVITY_IN_GLOBAL_SCOPE" value="false" />
    <option name="SKIP_ACTIVITY_VALIDATION" value="false" />
    <method v="2">
      <option name="Android.Gradle.BeforeRunTask" enabled="true" />
    </method>
  </configuration>
</component>
README.md (21 changed lines)
@@ -106,7 +106,9 @@ To speak, RTranslator uses the system TTS of your phone, so the quality of the l
The supported languages seen above are all compatible with <a href="https://play.google.com/store/apps/details?id=com.google.android.tts&pcampaignid=web_share">Google TTS</a>, which is the recommended TTS (although you can use the TTS you want).

-To change the system TTS (and therefore the TTS used by RTranslator), download the TTS you want to use from the Play Store, or from the source you prefer, and open RTranslator, then open its settings (top right) and, in the "Output" section, click on "Text to Speech". The system settings will open on the page where you can select the preferred system TTS engine (among those installed). If you have changed the preferred engine, restart RTranslator to apply the changes.
+To change the system TTS (and therefore the TTS used by RTranslator), download the TTS you want to use from the Play Store, or from the source you prefer, and open RTranslator, then open its settings (top right) and, in the "Output" section, click on "Text to Speech". The system settings will open on the page where you can select the preferred system TTS engine (among those installed). If you have changed the preferred engine, restart RTranslator to apply the changes (close it from the recent apps and then reopen it).

**Note:** If after that the TTS doesn't work, you can clear the cache of RTranslator and of the TTS in Android's application settings, reboot the phone, and retry.
<br /><br />

<h3>Privacy</h3>
@@ -167,6 +169,23 @@ So, if you like the app and want to say thank you and support the project, you c
In case you will donate, or just leave a star, thank you :heart:
<br /><br />

<h3>Connected external projects</h3>

Take a look at these awesome projects that use RTranslator code:

[**WhisperIMEplus**](https://github.com/woheller69/whisperIMEplus)

[**WhisperJET**](https://github.com/eix128/WhisperJET)
<br /><br />

<h3>Contributions</h3>

If you want to contribute to this project, first of all, thank you 🚀

If you don't know where to start, go check the [to-do list](https://github.com/niedev/RTranslator/blob/v2.00/TODO_LIST.md), and in any case, before starting, read the [contribution guidelines](https://github.com/niedev/RTranslator/blob/v2.00/CONTRIBUTING.md).
<br /><br />

<h3>Bugs and problems</h3>
I remind you that the app is still in beta. The bugs found are the following:
@@ -6,7 +6,7 @@ If you want to contribute to RTranslator but you don't know where to start, here
| Mic manual mode setting to choose whether to start with automatic or manual mic mode in WalkieTalkie or Conversation mode (separate settings for the two modes). | Not started |
| Button to stop and eventually repeat the TTS in the messages of WalkieTalkie and Conversation mode. | Not started |
| A new section in the settings that shows the version of the app, plus a link to the releases of the app that also shows if there is a new version. | Not started |
-| A new option in the settings to show, in WalkieTalkie and Conversation modes, the original transcription of the message beyond the translation. | Not started |
+| A new option in the settings to show, in WalkieTalkie and Conversation modes, the original transcription of the message beyond the translation. | [Done](https://github.com/niedev/RTranslator/pull/128) |

<br/><br/>
Here is also a list of more difficult things to do if you want (I don't expect anyone to do one of these all by themselves; even pull requests that lay down the foundations or make a small contribution toward these features are highly appreciated):
@@ -25,16 +25,41 @@ ext {
 
 android {
     namespace 'nie.translator.rtranslator'
-    compileSdkVersion 33 //33
+    //compileSdkVersion 33 //33
     ndkVersion "28.0.12674087"
     defaultConfig {
         applicationId "nie.translator.rtranslator"
-        targetSdkVersion 32 //32
+        targetSdkVersion 36 //32
+        compileSdk 36
         versionCode 24
         versionName '2.1.4'
-        minSdkVersion 24
+        minSdkVersion 28
         externalNativeBuild {
             cmake {
-                cppFlags ''
+                arguments '-DSPM_ENABLE_SHARED=OFF', //sentencepiece flags
+                        '-DSPM_BUILD_TEST=OFF',
+                        '-DSPM_ENABLE_TCMALLOC=OFF',
+                        //bergamot flags
+                        '-DANDROID_ABI=arm64-v8a',
+                        '-DANDROID_PLATFORM=android-28',
+                        '-DANDROID_STL=c++_static',
+                        "-DANDROID_PIE=ON",
+                        "-DANDROID_CPP_FEATURES=exceptions",
+                        '-DCMAKE_BUILD_TYPE=Release',
+                        '-DBUILD_SHARED_LIBS=OFF',
+                        '-DUSE_STATIC_LIBS=ON',
+                        '-DCOMPILE_CUDA=OFF',
+                        '-DUSE_WASM_COMPATIBLE_SOURCE=OFF',
+                        '-DTHREADS_PREFER_PTHREAD_FLAG=ON',
+                        '-DUSE_TCMALLOC=OFF',
+                        '-DSSPLIT_USE_INTERNAL_PCRE2:BOOL=ON',
+                        '-DUSE_PATHIECPP=ON',
+                        '-DBUILD_ARCH=armv8-a',
+                        '-GNinja',
+                        //'-DCMAKE_C_FLAGS="-O3 -fPIC -march=armv8-a -mtune=generic -ffast-math -funroll-loops -fdata-sections -ffunction-sections"',
+                        //'-DCMAKE_CXX_FLAGS="-O3 -fPIC -march=armv8-a -mtune=generic -ffast-math -funroll-loops -fdata-sections -ffunction-sections"',
+                        '-DCMAKE_CXX_FLAGS="-Wno-enum-constexpr-conversion"'
+                cppFlags '-std=c++17'
+                abiFilters 'arm64-v8a'
             }
         }
@@ -78,7 +103,7 @@ android {
     }
     externalNativeBuild {
         cmake {
-            path file('src/main/cpp/CMakeLists.txt')
+            path file('src/main/cpp/src/CMakeLists.txt')
             version '3.22.1'
         }
     }
@@ -86,7 +111,6 @@

dependencies {
    //implementation fileTree(include: ['*.jar'], dir: 'libs')
    implementation 'com.github.okitcom:SwitchButton:1.4.5'
    // Support libraries
    implementation "com.google.android.material:material:1.9.0" //1.9.0
    implementation "androidx.cardview:cardview:1.0.0"
@@ -95,6 +119,8 @@
    implementation "androidx.preference:preference:1.1.0-alpha02" //previously 1.1.0-alpha02
    implementation "androidx.core:core-splashscreen:1.0.1"
    implementation "androidx.lifecycle:lifecycle-extensions:2.2.0"
    //Download library
    implementation 'com.github.amitshekhariitbhu:PRDownloader:1.0.2'
    //implementation 'androidx.core:core-ktx:1.10.0'
    implementation 'androidx.work:work-runtime:2.7.1'
    implementation 'androidx.exifinterface:exifinterface:1.3.7'
@@ -9,6 +9,9 @@
    </intent>
    </queries>

    <uses-permission android:name="android.permission.MANAGE_EXTERNAL_STORAGE"/> <!-- todo: remove this permission -->

    <uses-permission android:name="android.permission.POST_NOTIFICATIONS"/>
    <uses-permission android:name="android.permission.ACCESS_NETWORK_STATE" />
    <uses-permission android:name="android.permission.INTERNET" />
    <uses-permission android:name="android.permission.SEND_DOWNLOAD_COMPLETED_INTENTS" />
@@ -16,6 +19,8 @@
     <uses-permission android:name="android.permission.RECORD_AUDIO" />
     <uses-permission android:name="android.permission.MODIFY_AUDIO_SETTINGS" />
     <uses-permission android:name="android.permission.FOREGROUND_SERVICE" />
+    <uses-permission android:name="android.permission.FOREGROUND_SERVICE_MICROPHONE" />
+    <uses-permission android:name="android.permission.FOREGROUND_SERVICE_DATA_SYNC" />

     <!--<uses-permission android:name="android.permission.ACCESS_WIFI_STATE" />
     <uses-permission android:name="android.permission.CHANGE_WIFI_STATE" />
@@ -26,8 +31,12 @@
     <uses-permission android:name="android.permission.BLUETOOTH_SCAN" android:usesPermissionFlags="neverForLocation"/>
     <uses-permission android:name="android.permission.BLUETOOTH_ADVERTISE" />
     <uses-permission android:name="android.permission.BLUETOOTH_CONNECT" />
-    <uses-permission android:name="android.permission.ACCESS_FINE_LOCATION" />
-    <uses-permission android:name="android.permission.ACCESS_COARSE_LOCATION" />
+    <uses-permission
+        android:name="android.permission.ACCESS_FINE_LOCATION"
+        android:maxSdkVersion="30"/>
+    <uses-permission
+        android:name="android.permission.ACCESS_COARSE_LOCATION"
+        android:maxSdkVersion="30"/>

     <!--<uses-permission android:name="android.permission.BLUETOOTH" />
     <uses-permission android:name="android.permission.BLUETOOTH_ADMIN" />
@@ -51,7 +60,8 @@
         android:largeHeap="true"
         android:supportsRtl="true"
         android:theme="@style/Theme.Speech"
-        tools:ignore="GoogleAppIndexingWarning">
+        android:enableOnBackInvokedCallback="false"
+        tools:ignore="GoogleAppIndexingWarning,NewApi">
         <activity
             android:name="nie.translator.rtranslator.LoadingActivity"
             android:exported="true"
@@ -95,6 +105,8 @@
        <service android:name="nie.translator.rtranslator.voice_translation._walkie_talkie_mode._walkie_talkie.WalkieTalkieService"
            android:foregroundServiceType="microphone"/>
        <service android:name="nie.translator.rtranslator.GeneralService" />
        <service android:name=".downloader2.DownloaderService"
            android:foregroundServiceType="dataSync" />

        <provider
            android:name="androidx.core.content.FileProvider"
@@ -108,10 +120,10 @@
         <uses-library
             android:name="com.sec.android.app.multiwindow"
             android:required="false" />
-        <receiver android:name=".access.DownloadReceiver" android:exported="true">
+        <!--<receiver android:name=".access.DownloadReceiver" android:exported="true">
             <intent-filter>
                 <action android:name="android.intent.action.DOWNLOAD_COMPLETE" />
             </intent-filter>
-        </receiver>
+        </receiver>-->
     </application>
</manifest>
@@ -1,77 +0,0 @@
Makefile
Makefile.in
/ar-lib
/mdate-sh
/py-compile
/test-driver
/ylwrap
/build

/autom4te.cache
/autoscan.log
/autoscan-*.log
/aclocal.m4
/compile
/config.guess
/config.sub
/configure
/configure.scan
/depcomp
/install-sh
/missing
/stamp-h1
/libtool
/config.h
/config.status
/autogen.sh
/ltmain.sh

CMakeFiles
CMakeCache.txt
config.h
sentencepiece.pc
CPackConfig.cmake
CTestTestfile.cmake
CPackSourceConfig.cmake
DartConfiguration.tcl

*.o
*.lo
*.a
*.la
*.pyc

.libs
.deps

*.m4
*.log
*.trs

compile_charsmap

spm_decode
spm_encode
spm_export_vocab
spm_train
spm_normalize
spm_test

.DS_Store
*.egg-info/
dist/
*.swp
*.swo
*.pyc

m.model
m.vocab

cmake_install.cmake
libsentencepiece.so*
libsentencepiece_train.so*
python/bundled
_sentencepiece.*.so
third_party/abseil-cpp

python/sentencepiece
@@ -1,195 +0,0 @@
# Copyright 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.!

cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
file(STRINGS "VERSION.txt" SPM_VERSION)
message(STATUS "VERSION: ${SPM_VERSION}")

if(POLICY CMP0091)
  cmake_policy(SET CMP0091 NEW)
endif()

project(sentencepiece VERSION ${SPM_VERSION} LANGUAGES C CXX)

option(SPM_ENABLE_NFKC_COMPILE "Enables NFKC compile" OFF)
option(SPM_ENABLE_SHARED "Builds shared libaries in addition to static libraries." ON)
option(SPM_BUILD_TEST "Builds test binaries." OFF)
option(SPM_COVERAGE "Runs gcov to test coverage." OFF)
option(SPM_ENABLE_TENSORFLOW_SHARED "Makes a tensorflow compatible shared file." OFF)
option(SPM_ENABLE_TCMALLOC "Enable TCMalloc if available." ON)
option(SPM_TCMALLOC_STATIC "Link static library of TCMALLOC." OFF)
option(SPM_NO_THREADLOCAL "Disable thread_local operator" OFF)
option(SPM_ENABLE_MSVC_MT_BUILD, "Use /MT flag in MSVC build" OFF)
option(SPM_CROSS_SYSTEM_PROCESSOR, "Override system processor" "")

set(SPM_PROTOBUF_PROVIDER "internal" CACHE STRING "Provider of protobuf library")
set_property(CACHE SPM_PROTOBUF_PROVIDER PROPERTY STRINGS "internal" "package")
set(SPM_ABSL_PROVIDER "internal" CACHE STRING "Provider of absl library")
set_property(CACHE SPM_ABSL_PROVIDER PROPERTY STRINGS "internal" "module" "package")

if (SPM_CROSS_SYSTEM_PROCESSOR)
  set(CMAKE_SYSTEM_PROCESSOR ${SPM_CROSS_SYSTEM_PROCESSOR})
endif()

# Disable shared build on windows
if(WIN32)
  set(SPM_ENABLE_SHARED OFF)
endif()

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

if((CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND
    CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 10.0) OR
   (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
    CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 8.0))
  string(APPEND CMAKE_CXX_FLAGS " -fmacro-prefix-map=${CMAKE_SOURCE_DIR}/=''")
endif()

if (UNIX)
  include(GNUInstallDirs)
  set(prefix ${CMAKE_INSTALL_PREFIX})
  set(exec_prefix "\${prefix}")
  set(libdir "\${exec_prefix}/${CMAKE_INSTALL_LIBDIR}")
  set(includedir "\${prefix}/${CMAKE_INSTALL_INCLUDEDIR}")
else()
  set(prefix ${CMAKE_INSTALL_PREFIX})
  set(exec_prefix "\${prefix}")
  set(libdir "\${exec_prefix}/lib")
  set(includedir "\${prefix}/include")
endif()
set(GNUCXX_STD_SUPPORT_VERSION "4.3")

if(${CMAKE_SYSTEM_NAME} STREQUAL "FreeBSD")
  add_definitions(-D_FREEBSD)
endif()

if (SPM_USE_BUILTIN_PROTOBUF)
  set(libprotobuf_lite "")
else()
  set(libprotobuf_lite "protobuf-lite")
endif()

if (MSVC)
  add_definitions("/wd4267 /wd4244 /wd4305 /Zc:strictStrings /utf-8")
  if (SPM_ENABLE_MSVC_MT_BUILD)
    string(REPLACE "/MD" "/MT" CMAKE_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
    string(REPLACE "/MD" "/MT" CMAKE_CXX_FLAGS_MINSIZEREL ${CMAKE_CXX_FLAGS_MINSIZEREL})
    string(REPLACE "/MD" "/MT" CMAKE_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
    string(REPLACE "/MD" "/MT" CMAKE_CXX_FLAGS_RELWITHDEBINFO ${CMAKE_CXX_FLAGS_RELWITHDEBINFO})
  endif()
endif()

if (APPLE)
  set(CMAKE_MACOSX_RPATH ON)
  set(CMAKE_SKIP_BUILD_RPATH FALSE)
  set(CMAKE_BUILD_WITH_INSTALL_RPATH FALSE)
  set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib")
  set(CMAKE_INSTALL_RPATH_USE_LINK_PATH TRUE)
  list(FIND CMAKE_PLATFORM_IMPLICIT_LINK_DIRECTORIES "${CMAKE_INSTALL_PREFIX}/lib" isSystemDir)
  if ("${isSystemDir}" STREQUAL "-1")
    set(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_PREFIX}/lib")
  endif()
endif()

if (NOT DEFINED CMAKE_INSTALL_BINDIR)
  set(CMAKE_INSTALL_BINDIR bin)
endif()

if (NOT DEFINED CMAKE_INSTALL_LIBDIR)
  set(CMAKE_INSTALL_LIBDIR lib)
endif()

if (NOT DEFINED CMAKE_INSTALL_INCDIR)
  set(CMAKE_INSTALL_INCDIR include)
endif()

# SPDX-License-Identifier: (MIT OR CC0-1.0)
# Copyright 2020 Jan Tojnar
# https://github.com/jtojnar/cmake-snips
#
# Modelled after Python’s os.path.join
# https://docs.python.org/3.7/library/os.path.html#os.path.join
# Windows not supported
function(join_paths joined_path first_path_segment)
  set(temp_path "${first_path_segment}")
  foreach(current_segment IN LISTS ARGN)
    if(NOT ("${current_segment}" STREQUAL ""))
      if(IS_ABSOLUTE "${current_segment}")
        set(temp_path "${current_segment}")
      else()
        set(temp_path "${temp_path}/${current_segment}")
      endif()
    endif()
  endforeach()
  set(${joined_path} "${temp_path}" PARENT_SCOPE)
endfunction()

join_paths(libdir_for_pc_file "\${exec_prefix}" "${CMAKE_INSTALL_LIBDIR}")
join_paths(includedir_for_pc_file "\${prefix}" "${CMAKE_INSTALL_INCLUDEDIR}")

configure_file("${PROJECT_SOURCE_DIR}/config.h.in" "config.h")
configure_file("${PROJECT_SOURCE_DIR}/sentencepiece.pc.in" "sentencepiece.pc" @ONLY)

if (NOT MSVC)
  # suppress warning for C++11 features.
  # add_definitions("-Wno-deprecated-declarations -Wno-deprecated-enum-enum-conversion")
  install(FILES "${CMAKE_CURRENT_BINARY_DIR}/sentencepiece.pc" DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig)
endif()

include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${PROJECT_BINARY_DIR})

if (SPM_BUILD_TEST)
  enable_testing()
endif()

if (SPM_ABSL_PROVIDER STREQUAL "internal")
  include_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/absl)
elseif (SPM_ABSL_PROVIDER STREQUAL "module")
  include(FetchContent)
  FetchContent_Populate(abseil-cpp
    GIT_REPOSITORY https://github.com/abseil/abseil-cpp.git
    SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/third_party/abseil-cpp
    GIT_PROGRESS TRUE)
  add_subdirectory(third_party/abseil-cpp)
  if (NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/third_party/absl.org)
    file(RENAME ${CMAKE_CURRENT_SOURCE_DIR}/third_party/absl ${CMAKE_CURRENT_SOURCE_DIR}/third_party/absl.org)
    execute_process(COMMAND ${CMAKE_COMMAND} -E create_symlink
      ${CMAKE_CURRENT_SOURCE_DIR}/third_party/abseil-cpp/absl
      ${CMAKE_CURRENT_SOURCE_DIR}/third_party/absl)
  endif()
elseif (SPM_ABSL_PROVIDER STREQUAL "package")
  find_package(absl REQUIRED)
  get_target_property(ABSL_INCLUDE_DIRS absl::base INTERFACE_INCLUDE_DIRECTORIES)
  if (NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/third_party/absl.org)
    file(RENAME ${CMAKE_CURRENT_SOURCE_DIR}/third_party/absl ${CMAKE_CURRENT_SOURCE_DIR}/third_party/absl.org)
    execute_process(COMMAND ${CMAKE_COMMAND} -E create_symlink
      ${ABSL_INCLUDE_DIRS}/absl ${CMAKE_CURRENT_SOURCE_DIR}/third_party/absl)
  endif()
  include_directories(${ABSL_INCLUDE_DIRS})
endif()

add_subdirectory(src)
add_subdirectory(third_party)

set(CPACK_SOURCE_GENERATOR "TXZ")
set(CPACK_GENERATOR "7Z")
set(CPACK_PACKAGE_VERSION "${SPM_VERSION}")
set(CPACK_STRIP_FILES TRUE)
set(CPACK_RESOURCE_FILE_LICENSE "${PROJECT_SOURCE_DIR}/LICENSE")
set(CPACK_RESOURCE_FILE_README "${PROJECT_SOURCE_DIR}/README.md")
set(CPACK_PACKAGE_CONTACT "taku@google.com")
set(CPACK_DEBIAN_PACKAGE_MAINTAINER "Taku Kudo")
set(CPACK_SOURCE_IGNORE_FILES "/build/;/.git/;/dist/;/sdist/;~$;${CPACK_SOURCE_IGNORE_FILES}")
include(CPack)
@@ -1,24 +0,0 @@
Want to contribute? Great! First, read this page (including the small print at the end).

### Before you contribute
Before we can use your code, you must sign the [Google Individual Contributor License Agreement](https://cla.developers.google.com/about/google-individual) (CLA), which you can do online. The CLA is necessary mainly because you own the copyright to your changes even after your contribution becomes part of our codebase, so we need your permission to use and distribute your code. We also need to be sure of various other things—for instance, that you'll tell us if you know that your code infringes on other people's patents. You don't have to sign the CLA until after you've submitted your code for review and a member has approved it, but you must do it before we can put your code into our codebase. Before you start working on a larger contribution, you should get in touch with us first through the issue tracker with your idea so that we can help out and possibly guide you. Coordinating up-front makes it much easier to avoid frustration later on.

### Code reviews
All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose.

### The small print
Contributions made by corporations are covered by a different agreement than the one above, the [Software Grant and Corporate Contributor License Agreement](https://cla.developers.google.com/about/google-corporate).
@@ -1,202 +0,0 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.

You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
@@ -1,294 +0,0 @@
# SentencePiece

SentencePiece is an unsupervised text tokenizer and detokenizer mainly for Neural Network-based text generation systems where the vocabulary size is predetermined prior to the neural model training. SentencePiece implements **subword units** (e.g., **byte-pair-encoding (BPE)** [[Sennrich et al.](https://www.aclweb.org/anthology/P16-1162)] and **unigram language model** [[Kudo.](https://arxiv.org/abs/1804.10959)]) with the extension of direct training from raw sentences. SentencePiece allows us to make a purely end-to-end system that does not depend on language-specific pre/postprocessing.

**This is not an official Google product.**

## Technical highlights
- **Purely data driven**: SentencePiece trains tokenization and detokenization models from sentences. Pre-tokenization ([Moses tokenizer](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl)/[MeCab](http://taku910.github.io/mecab/)/[KyTea](http://www.phontron.com/kytea/)) is not always required.
- **Language independent**: SentencePiece treats the sentences just as sequences of Unicode characters. There is no language-dependent logic.
- **Multiple subword algorithms**: **BPE** [[Sennrich et al.](https://www.aclweb.org/anthology/P16-1162)] and **unigram language model** [[Kudo.](https://arxiv.org/abs/1804.10959)] are supported.
- **Subword regularization**: SentencePiece implements subword sampling for [subword regularization](https://arxiv.org/abs/1804.10959) and [BPE-dropout](https://arxiv.org/abs/1910.13267), which help to improve the robustness and accuracy of NMT models.
- **Fast and lightweight**: Segmentation speed is around 50k sentences/sec, and the memory footprint is around 6MB.
- **Self-contained**: The same tokenization/detokenization is obtained as long as the same model file is used.
- **Direct vocabulary id generation**: SentencePiece manages the vocabulary-to-id mapping and can directly generate vocabulary id sequences from raw sentences.
- **NFKC-based normalization**: SentencePiece performs NFKC-based text normalization.

For those unfamiliar with SentencePiece as a software/algorithm, one can read [a gentle introduction here](https://medium.com/@jacky2wong/understanding-sentencepiece-under-standing-sentence-piece-ac8da59f6b08).

## Comparisons with other implementations
|Feature|SentencePiece|[subword-nmt](https://github.com/rsennrich/subword-nmt)|[WordPiece](https://arxiv.org/pdf/1609.08144.pdf)|
|:---|:---:|:---:|:---:|
|Supported algorithm|BPE, unigram, char, word|BPE|BPE*|
|OSS?|Yes|Yes|Google internal|
|Subword regularization|[Yes](#subword-regularization-and-bpe-dropout)|No|No|
|Python Library (pip)|[Yes](python/README.md)|No|N/A|
|C++ Library|[Yes](doc/api.md)|No|N/A|
|Pre-segmentation required?|[No](#whitespace-is-treated-as-a-basic-symbol)|Yes|Yes|
|Customizable normalization (e.g., NFKC)|[Yes](doc/normalization.md)|No|N/A|
|Direct id generation|[Yes](#end-to-end-example)|No|N/A|

Note that the BPE algorithm used in WordPiece is slightly different from the original BPE.

## Overview
### What is SentencePiece?
SentencePiece is a re-implementation of **sub-word units**, an effective way to alleviate the open vocabulary problems in neural machine translation. SentencePiece supports two segmentation algorithms, **byte-pair-encoding (BPE)** [[Sennrich et al.](http://www.aclweb.org/anthology/P16-1162)] and **unigram language model** [[Kudo.](https://arxiv.org/abs/1804.10959)]. Here are the high level differences from other implementations.

#### The number of unique tokens is predetermined
Neural Machine Translation models typically operate with a fixed vocabulary. Unlike most unsupervised word segmentation algorithms, which assume an infinite vocabulary, SentencePiece trains the segmentation model such that the final vocabulary size is fixed, e.g., 8k, 16k, or 32k.

Note that SentencePiece specifies the final vocabulary size for training, which is different from [subword-nmt](https://github.com/rsennrich/subword-nmt), which uses the number of merge operations. The number of merge operations is a BPE-specific parameter and not applicable to other segmentation algorithms, including unigram, word and character.

#### Trains from raw sentences
Previous sub-word implementations assume that the input sentences are pre-tokenized. This constraint was required for efficient training, but it makes the preprocessing complicated, as we have to run language-dependent tokenizers in advance. The implementation of SentencePiece is fast enough to train the model from raw sentences. This is useful for training the tokenizer and detokenizer for Chinese and Japanese, where no explicit spaces exist between words.

#### Whitespace is treated as a basic symbol
The first step of Natural Language processing is text tokenization. For example, a standard English tokenizer would segment the text "Hello world." into the following three tokens.

> [Hello] [World] [.]

One observation is that the original input and tokenized sequence are **NOT reversibly convertible**. For instance, the information that there is no space between “World” and “.” is dropped from the tokenized sequence, since e.g., `Tokenize(“World.”) == Tokenize(“World .”)`

SentencePiece treats the input text just as a sequence of Unicode characters. Whitespace is also handled as a normal symbol. To handle the whitespace as a basic token explicitly, SentencePiece first escapes the whitespace with a meta symbol "▁" (U+2581) as follows.

> Hello▁World.

Then, this text is segmented into small pieces, for example:

> [Hello] [▁Wor] [ld] [.]

Since the whitespace is preserved in the segmented text, we can detokenize the text without any ambiguities.

```
detokenized = ''.join(pieces).replace('▁', ' ')
```

This feature makes it possible to perform detokenization without relying on language-specific resources.

Note that we cannot apply the same lossless conversions when splitting the sentence with standard word segmenters, since they treat the whitespace as a special symbol. Tokenized sequences do not preserve the necessary information to restore the original sentence.

* (en) Hello world. → [Hello] [World] [.] \(A space between Hello and World\)
* (ja) こんにちは世界。 → [こんにちは] [世界] [。] \(No space between こんにちは and 世界\)
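
Concretely, the round trip looks like this with the Python wrapper (a minimal sketch; it assumes a model file `m.model` such as the one trained in the end-to-end example below, and the exact pieces depend on the model):

```python
import sentencepiece as spm

# Assumes an existing model, e.g. trained with:
#   % spm_train --input=data/botchan.txt --model_prefix=m --vocab_size=1000
sp = spm.SentencePieceProcessor(model_file='m.model')

pieces = sp.encode('Hello World.', out_type=str)
print(pieces)  # e.g. ['▁Hello', '▁World', '.'] (model-dependent)

# Lossless detokenization: join the pieces and map the meta symbol back to spaces.
detokenized = ''.join(pieces).replace('▁', ' ').lstrip()
print(detokenized)  # 'Hello World.'
```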
|
||||
|
||||
#### Subword regularization and BPE-dropout
|
||||
Subword regularization [[Kudo.](https://arxiv.org/abs/1804.10959)] and BPE-dropout [Provilkov et al](https://arxiv.org/abs/1910.13267) are simple regularization methods
|
||||
that virtually augment training data with on-the-fly subword sampling, which helps to improve the accuracy as well as robustness of NMT models.
|
||||
|
||||
To enable subword regularization, you would like to integrate SentencePiece library
|
||||
([C++](doc/api.md#sampling-subword-regularization)/[Python](python/README.md)) into the NMT system to sample one segmentation for each parameter update, which is different from the standard off-line data preparations. Here's the example of [Python library](python/README.md). You can find that 'New York' is segmented differently on each ``SampleEncode (C++)`` or ``encode with enable_sampling=True (Python)`` calls. The details of sampling parameters are found in [sentencepiece_processor.h](src/sentencepiece_processor.h).
|
||||
|
||||
```
|
||||
>>> import sentencepiece as spm
|
||||
>>> s = spm.SentencePieceProcessor(model_file='spm.model')
|
||||
>>> for n in range(5):
|
||||
... s.encode('New York', out_type=str, enable_sampling=True, alpha=0.1, nbest_size=-1)
|
||||
...
|
||||
['▁', 'N', 'e', 'w', '▁York']
|
||||
['▁', 'New', '▁York']
|
||||
['▁', 'New', '▁Y', 'o', 'r', 'k']
|
||||
['▁', 'New', '▁York']
|
||||
['▁', 'New', '▁York']
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
### Python module
|
||||
SentencePiece provides Python wrapper that supports both SentencePiece training and segmentation.
|
||||
You can install Python binary package of SentencePiece with.
|
||||
|
||||
```
|
||||
pip install sentencepiece
|
||||
```
|
||||
|
||||
For more detail, see [Python module](python/README.md)
|
||||
|
||||
### Build and install SentencePiece command line tools from C++ source
|
||||
The following tools and libraries are required to build SentencePiece:
|
||||
|
||||
* [cmake](https://cmake.org/)
|
||||
* C++11 compiler
|
||||
* [gperftools](https://github.com/gperftools/gperftools) library (optional, 10-40% performance improvement can be obtained.)
|
||||
|
||||
On Ubuntu, the build tools can be installed with apt-get:
|
||||
```
|
||||
% sudo apt-get install cmake build-essential pkg-config libgoogle-perftools-dev
|
||||
```
|
||||
|
||||
Then, you can build and install command line tools as follows.
|
||||
```
|
||||
% git clone https://github.com/google/sentencepiece.git
|
||||
% cd sentencepiece
|
||||
% mkdir build
|
||||
% cd build
|
||||
% cmake ..
|
||||
% make -j $(nproc)
|
||||
% sudo make install
|
||||
% sudo ldconfig -v
|
||||
```
|
||||
On OSX/macOS, replace the last command with `sudo update_dyld_shared_cache`
|
||||
|
||||
### Build and install using vcpkg
|
||||
|
||||
You can download and install sentencepiece using the [vcpkg](https://github.com/Microsoft/vcpkg) dependency manager:
|
||||
|
||||
git clone https://github.com/Microsoft/vcpkg.git
|
||||
cd vcpkg
|
||||
./bootstrap-vcpkg.sh
|
||||
./vcpkg integrate install
|
||||
./vcpkg install sentencepiece
|
||||
|
||||
The sentencepiece port in vcpkg is kept up to date by Microsoft team members and community contributors. If the version is out of date, please [create an issue or pull request](https://github.com/Microsoft/vcpkg) on the vcpkg repository.
|
||||
|
||||
### Download and install SentencePiece from signed released wheels
|
||||
|
||||
You can download the wheel from the [GitHub releases page](https://github.com/google/sentencepiece/releases/latest).
|
||||
We generate [SLSA3 signatures](slsa.dev) using the OpenSSF's [slsa-framework/slsa-github-generator](https://github.com/slsa-framework/slsa-github-generator) during the release process. To verify a release binary:
|
||||
1. Install the verification tool from [slsa-framework/slsa-verifier#installation](https://github.com/slsa-framework/slsa-verifier#installation).
|
||||
2. Download the provenance file `attestation.intoto.jsonl` from the [GitHub releases page](https://github.com/google/sentencepiece/releases/latest).
|
||||
3. Run the verifier:
|
||||
```shell
|
||||
slsa-verifier -artifact-path <the-wheel> -provenance attestation.intoto.jsonl -source github.com/google/sentencepiece -tag <the-tag>
|
||||
```
|
||||
|
||||
pip install wheel_file.whl
|
||||
|
||||
## Usage instructions
|
||||
### Train SentencePiece Model
|
||||
```
|
||||
% spm_train --input=<input> --model_prefix=<model_name> --vocab_size=8000 --character_coverage=1.0 --model_type=<type>
|
||||
```
|
||||
* `--input`: one-sentence-per-line **raw** corpus file. There is no need to run a tokenizer, normalizer, or preprocessor. By default, SentencePiece normalizes the input with Unicode NFKC. You can pass a comma-separated list of files.
* `--model_prefix`: output model name prefix. `<model_name>.model` and `<model_name>.vocab` are generated.
* `--vocab_size`: vocabulary size, e.g., 8000, 16000, or 32000.
* `--character_coverage`: amount of characters covered by the model. Good defaults are `0.9995` for languages with a rich character set like Japanese or Chinese, and `1.0` for other languages with a small character set.
* `--model_type`: model type. Choose from `unigram` (default), `bpe`, `char`, or `word`. The input sentence must be pretokenized when using the `word` type.

Use the `--help` flag to display all parameters for training, or see [here](doc/options.md) for an overview.
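The same training can be run through the Python wrapper; a minimal sketch, assuming an input corpus file `corpus.txt` (a hypothetical name):

```python
import sentencepiece as spm

# Equivalent of the spm_train invocation above, driven from Python.
spm.SentencePieceTrainer.train(
    input='corpus.txt',      # one-sentence-per-line raw corpus
    model_prefix='mymodel',  # writes mymodel.model and mymodel.vocab
    vocab_size=8000,
    character_coverage=1.0,
    model_type='unigram',
)
```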

### Encode raw text into sentence pieces/ids

```
% spm_encode --model=<model_file> --output_format=piece < input > output
% spm_encode --model=<model_file> --output_format=id < input > output
```

Use the `--extra_options` flag to insert the BOS/EOS markers or reverse the input sequence.

```
% spm_encode --extra_options=eos (add </s> only)
% spm_encode --extra_options=bos:eos (add <s> and </s>)
% spm_encode --extra_options=reverse:bos:eos (reverse input and add <s> and </s>)
```

SentencePiece supports nbest segmentation and segmentation sampling with the `--output_format=(nbest|sample)_(piece|id)` flags.

```
% spm_encode --model=<model_file> --output_format=sample_piece --nbest_size=-1 --alpha=0.5 < input > output
% spm_encode --model=<model_file> --output_format=nbest_id --nbest_size=10 < input > output
```
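The same sampling is available from Python; a minimal sketch, assuming the `m.model` file from the end-to-end example below:

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file='m.model')

# Sample a different segmentation on each call (subword regularization).
for _ in range(3):
    print(sp.encode('New York', out_type=str,
                    enable_sampling=True, nbest_size=-1, alpha=0.5))
```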

### Decode sentence pieces/ids into raw text

```
% spm_decode --model=<model_file> --input_format=piece < input > output
% spm_decode --model=<model_file> --input_format=id < input > output
```

Use the `--extra_options` flag to decode the text in reverse order.

```
% spm_decode --extra_options=reverse < input > output
```

### End-to-End Example

```
% spm_train --input=data/botchan.txt --model_prefix=m --vocab_size=1000
unigram_model_trainer.cc(494) LOG(INFO) Starts training with :
input: "../data/botchan.txt"
... <snip>
unigram_model_trainer.cc(529) LOG(INFO) EM sub_iter=1 size=1100 obj=10.4973 num_tokens=37630 num_tokens/piece=34.2091
trainer_interface.cc(272) LOG(INFO) Saving model: m.model
trainer_interface.cc(281) LOG(INFO) Saving vocabs: m.vocab

% echo "I saw a girl with a telescope." | spm_encode --model=m.model
▁I ▁saw ▁a ▁girl ▁with ▁a ▁ te le s c o pe .

% echo "I saw a girl with a telescope." | spm_encode --model=m.model --output_format=id
9 459 11 939 44 11 4 142 82 8 28 21 132 6

% echo "9 459 11 939 44 11 4 142 82 8 28 21 132 6" | spm_decode --model=m.model --input_format=id
I saw a girl with a telescope.
```

You can see that the original input sentence is restored from the vocabulary id sequence.

### Export vocabulary list

```
% spm_export_vocab --model=<model_file> --output=<output file>
```

```<output file>``` stores a list of vocabulary items and emission log probabilities. The vocabulary id corresponds to the line number in this file.

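Because ids map to line numbers, the exported file is easy to inspect programmatically; a minimal sketch, assuming a file `m.vocab` exported as above with one piece and its score per line, tab-separated (the default `--vocabulary_output_piece_score` behavior):

```python
# Build an id -> (piece, score) table from an exported vocabulary file.
id_to_entry = {}
with open('m.vocab', encoding='utf-8') as f:
    for vocab_id, line in enumerate(f):
        piece, score = line.rstrip('\n').split('\t')
        id_to_entry[vocab_id] = (piece, float(score))

print(id_to_entry[0])  # e.g., the entry for id 0 (<unk> by default)
```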
### Redefine special meta tokens

By default, SentencePiece uses Unknown (`<unk>`), BOS (`<s>`) and EOS (`</s>`) tokens, which have the ids 0, 1, and 2 respectively. We can redefine this mapping in the training phase as follows.

```
% spm_train --bos_id=0 --eos_id=1 --unk_id=5 --input=... --model_prefix=... --character_coverage=...
```

When an id is set to -1, e.g., ```--bos_id=-1```, the corresponding special token is disabled. Note that the unknown id cannot be disabled. We can define an id for padding (`<pad>`) as ```--pad_id=3```.

If you want to assign other special tokens, please see [Use custom symbols](doc/special_symbols.md).
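You can confirm the resulting mapping from Python; a minimal sketch, assuming a model file trained with the redefined ids above:

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file='m.model')

# Inspect which piece each low id maps to under the redefined mapping.
for i in range(4):
    print(i, sp.id_to_piece(i))

print(sp.piece_to_id('<s>'))  # BOS id under the redefined mapping
```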

### Vocabulary restriction

```spm_encode``` accepts a ```--vocabulary``` and a ```--vocabulary_threshold``` option so that ```spm_encode``` will only produce symbols which also appear in the vocabulary (with at least some frequency). The background of this feature is described in the [subword-nmt page](https://github.com/rsennrich/subword-nmt#best-practice-advice-for-byte-pair-encoding-in-nmt).

The usage is basically the same as that of ```subword-nmt```. Assuming that L1 and L2 are the two languages (source/target languages), train the shared spm model, and get the resulting vocabulary for each:

```
% cat {train_file}.L1 {train_file}.L2 | shuffle > train
% spm_train --input=train --model_prefix=spm --vocab_size=8000 --character_coverage=0.9995
% spm_encode --model=spm.model --generate_vocabulary < {train_file}.L1 > {vocab_file}.L1
% spm_encode --model=spm.model --generate_vocabulary < {train_file}.L2 > {vocab_file}.L2
```

The ```shuffle``` command is used as a precaution, because ```spm_train``` loads only the first 10M lines of the corpus by default.

Then segment the train/test corpus with the ```--vocabulary``` option:

```
% spm_encode --model=spm.model --vocabulary={vocab_file}.L1 --vocabulary_threshold=50 < {test_file}.L1 > {test_file}.seg.L1
% spm_encode --model=spm.model --vocabulary={vocab_file}.L2 --vocabulary_threshold=50 < {test_file}.L2 > {test_file}.seg.L2
```

## Advanced topics

* [SentencePiece Experiments](doc/experiments.md)
* [SentencePieceProcessor C++ API](doc/api.md)
* [Use custom text normalization rules](doc/normalization.md)
* [Use custom symbols](doc/special_symbols.md)
* [Python Module](python/README.md)
* [Segmentation and training algorithms in detail]

@@ -1 +0,0 @@
0.2.0
@@ -1,9 +0,0 @@
#ifndef CONFIG_H_
#define CONFIG_H_

#define VERSION "@PROJECT_VERSION@"
#define PACKAGE "@PROJECT_NAME@"
#define PACKAGE_STRING "@PROJECT_NAME@"

#endif  // CONFIG_H_
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,45 +0,0 @@
#!/usr/bin/perl

# Copyright 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Extract header files required to build protobuf-lite
#
# usage: ./extract_headers.pl *.cc

use strict;
use warnings;

sub Process() {
  my $file = shift @_;
  if ($file =~ /\.h$/) {
    print "$file\n";
  }
  return unless open(F, $file);
  my @files = ();
  while (<F>) {
    chomp;
    if (/\#include <(google\/protobuf\/[^>]+)>/) {
      push @files, $1;
    }
  }
  close(F);
  for my $file (@files) {
    &Process($file);
  }
}

for my $f (@ARGV) {
  &Process($f);
}
@@ -1,175 +0,0 @@
#!/usr/bin/perl

# Copyright 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Generate spec_parser.h from sentencepiece_model.proto
#
# usage: ./gen_spec_parser.pl sentencepiece_model.proto > spec_parser.h

use strict;
use warnings;

sub ProcessPrinter() {
  my ($filename) = @_;
  my $classname = "";
  my $valid = 0;
  my %enum;
  open(F, $filename) || die;
  print "namespace {\n";
  while (<F>) {
    chomp;
    if (/^\s*message (\S+)/) {
      $classname = $1;
      $valid = 0;
      if ($classname =~ /(TrainerSpec|NormalizerSpec)/) {
        print "inline std::string PrintProto(const $classname &message) {\n";
        print " std::ostringstream os;\n\n";
        print " os << \"$classname {\\n\";\n";
        $valid = 1;
      }
    } elsif (/^\s*}/) {
      next if (!$valid);
      print " os << \"}\\n\";\n";
      print "\n return os.str();\n";
      print "}\n\n";
    } elsif (/enum\s*(\S+)/) {
      my $name = $1;
      $enum{$name} = 1;
      next if (!$valid);
      print " static const std::map<$classname::$name, std::string> k${name}_Map = { ";
      while (<F>) {
        if (/(\S+)\s*=\s*(\d+)/) {
          print "{$classname::$1, \"$1\"}, ";
        } elsif (/}/) {
          print " };\n";
          last;
        }
      }
    } elsif (/\s*(repeated|optional)\s+(\S+)\s+(\S+)\s*=\s*(\d+)/) {
      next if (/deprecated = true/);
      next if (!$valid);
      my $opt = $1;
      my $type = $2;
      my $name = $3;
      if ($type =~ /(int|double|float|bool|string)/) {
        if ($opt eq "optional") {
          print " os << \" ${name}: \" << message.${name}() << \"\\n\";\n";
        } else {
          print " for (const auto &v : message.${name}())\n";
          print " os << \" ${name}: \" << v << \"\\n\";\n";
        }
      } elsif (defined $enum{$type}) {
        if ($opt eq "optional") {
          print " {\n";
          print " const auto it = k${type}_Map.find(message.${name}());\n";
          print " if (it == k${type}_Map.end())\n";
          print " os << \" ${name}: unknown\\n\";\n";
          print " else\n";
          print " os << \" ${name}: \" << it->second << \"\\n\";\n";
          print " }\n";
        } else {
          print " for (const auto &v : message.${name}()) {\n";
          print " const auto it = k${type}_Map.find(v);\n";
          print " if (it == k${type}_Map.end())\n";
          print " os << \" ${name}: unknown\\n\";\n";
          print " else\n";
          print " os << \" ${name}: \" << it->second << \"\\n\";\n";
          print " }\n";
        }
      }
    }
  }
  print "} // namespace\n\n";
  close(F);
}

sub ProcessParser() {
  my ($filename) = @_;
  my $classname = "";
  my $valid = 0;
  my %enum;
  open(F, $filename) || die;
  while (<F>) {
    if (/^\s*message (\S+)/) {
      $classname = $1;
      $valid = 0;
      if ($classname =~ /(TrainerSpec|NormalizerSpec)/) {
        print "util::Status SentencePieceTrainer::SetProtoField(const std::string& name, const std::string& value, $classname *message) {\n";
        print " CHECK_OR_RETURN(message);\n\n";
        $valid = 1;
      }
    } elsif (/^\s*}/) {
      next if (!$valid);
      print " return util::StatusBuilder(util::error::NOT_FOUND)\n";
      print " << \"unknown field name \\\"\" << name << \"\\\" in ${classname}.\";\n";
      print "}\n\n";
    } elsif (/enum\s*(\S+)/) {
      my $name = $1;
      $enum{$name} = 1;
      next if (!$valid);
      print " static const std::map <std::string, $classname::$name> k${name}_Map = { ";
      while (<F>) {
        if (/(\S+)\s*=\s*(\d+)/) {
          print "{\"$1\", $classname::$1}, ";
        } elsif (/}/) {
          print " };\n\n";
          last;
        }
      }
    } elsif (/\s*(repeated|optional)\s+(\S+)\s+(\S+)\s*=\s*(\d+)/) {
      next if (/deprecated = true/);
      next if (!$valid);
      my $opt = $1;
      my $type = $2;
      my $name = $3;
      my $func_prefix = $opt eq "optional" ? "set_" : "add_";
      my $body = "";
      if ($type =~ /(int|double|float|bool)/) {
        my $empty = $type eq "bool" ? "\"true\"" : "\"\"";
        $body =
            "${type} v;\n" .
            " if (!string_util::lexical_cast(val.empty() ? ${empty} : val, &v))\n" .
            " return util::StatusBuilder(util::error::INVALID_ARGUMENT) << \"cannot parse \\\"\" << val << \"\\\" as ${type}.\";\n" .
            " message->${func_prefix}${name}(v);\n";
      } elsif ($type =~ /string/) {
        $body = "message->${func_prefix}${name}(val);\n";
      } elsif ($type =~ /bytes/) {
        $body = "message->${func_prefix}${name}(val.data(), val.size());\n";
      } elsif (defined $enum{$type}) {
        $body = "const auto it = k${type}_Map.find(string_util::ToUpper(val));\n" .
            " if (it == k${type}_Map.end())\n" .
            " return util::StatusBuilder(util::error::INVALID_ARGUMENT) << \"unknown enumeration value of \\\"\" << val << \"\\\" as ${type}.\";\n" .
            " message->${func_prefix}${name}(it->second);\n";
      }
      print " if (name == \"${name}\") {\n";
      if ($opt eq "repeated") {
        print " for (const auto &val : string_util::Split(value, \",\")) {\n";
        print " ${body}";
        print " }\n";
      } else {
        print " const auto &val = value;\n";
        print " ${body}";
      }
      print " return util::OkStatus();\n";
      print " }\n\n";
    }
  }
  close(F);
}

for my $file (@ARGV) {
  &ProcessPrinter($file);
  &ProcessParser($file);
}
@@ -1,45 +0,0 @@
#!/usr/bin/perl

# Copyright 2016 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Generate unicode_script_data.h from Unicode Scripts.txt
#
# usage: ./gen_unicode_Scripts_code.pl < scripts > unicode_script_data.h
#
print "#ifndef UNICODE_SCRIPT_DATA_H_\n";
print "#define UNICODE_SCRIPT_DATA_H_\n";
print "namespace sentencepiece {\n";
print "namespace unicode_script {\n";
print "namespace {\n";
print "void InitTable(std::unordered_map<char32, ScriptType> *smap) {\n";
print " CHECK_NOTNULL(smap)->clear();\n";

while (<>) {
  chomp;
  if (/^([0-9A-F]+)\s+;\s+(\S+)\s+\#/) {
    printf(" (*smap)[0x%s] = U_%s;\n", $1, $2);
  } elsif (/^([0-9A-F]+)\.\.([0-9A-F]+)\s+;\s+(\S+)\s+\#/) {
    printf(" for (char32 c = 0x%s; c <= 0x%s; ++c)\n", $1, $2);
    printf(" (*smap)[c] = U_%s;\n", $3);
  } else {
    next;
  }
}

print "}\n";
print "} // namespace\n";
print "} // namespace unicode_script\n";
print "} // namespace sentencepiece\n";
print "#endif // UNICODE_SCRIPT_DATA_H_\n";
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because one or more lines are too long
@@ -1,129 +0,0 @@
# SentencePieceProcessor C++ API

## Load SentencePiece model
To start working with the SentencePiece model, you will want to include the `sentencepiece_processor.h` header file.
Then instantiate the `sentencepiece::SentencePieceProcessor` class and call the `Load` method to load the model from a file path or `std::istream`.

```C++
#include <sentencepiece_processor.h>

sentencepiece::SentencePieceProcessor processor;
const auto status = processor.Load("/path/to/model.model");
if (!status.ok()) {
  std::cerr << status.ToString() << std::endl;
  // error
}

// You can also load a serialized model from std::string.
// const std::string str = // Load blob contents from a file.
// auto status = processor.LoadFromSerializedProto(str);
```

## Tokenize text (preprocessing)
Call the `SentencePieceProcessor::Encode` method to tokenize text.

```C++
std::vector<std::string> pieces;
processor.Encode("This is a test.", &pieces);
for (const std::string &token : pieces) {
  std::cout << token << std::endl;
}
```

You can obtain the sequence of vocab ids in the same way:

```C++
std::vector<int> ids;
processor.Encode("This is a test.", &ids);
for (const int id : ids) {
  std::cout << id << std::endl;
}
```

## Detokenize text (postprocessing)
Call the `SentencePieceProcessor::Decode` method to detokenize a sequence of pieces or ids into text. It is basically guaranteed that detokenization is the inverse operation of Encode, i.e., `Decode(Encode(Normalize(input))) == Normalize(input)`.

```C++
std::vector<std::string> pieces = { "▁This", "▁is", "▁a", "▁", "te", "st", "." };  // sequence of pieces
std::string text;
processor.Decode(pieces, &text);
std::cout << text << std::endl;

std::vector<int> ids = { 451, 26, 20, 3, 158, 128, 12 };  // sequence of ids
processor.Decode(ids, &text);
std::cout << text << std::endl;
```

## Sampling (subword regularization)
Call the `SentencePieceProcessor::SampleEncode` method to sample one segmentation.

```C++
std::vector<std::string> pieces;
processor.SampleEncode("This is a test.", &pieces, -1, 0.2);

std::vector<int> ids;
processor.SampleEncode("This is a test.", &ids, -1, 0.2);
```
SampleEncode has two sampling parameters, `nbest_size` and `alpha`, which correspond to `l` and `alpha` in the [original paper](https://arxiv.org/abs/1804.10959). When `nbest_size` is -1, one segmentation is sampled from all hypotheses with the forward-filtering and backward-sampling algorithm.

## Training
Call the `SentencePieceTrainer::Train` function to train a sentencepiece model. You can pass the same parameters as [spm_train](https://github.com/google/sentencepiece#train-sentencepiece-model) in a single string.

```C++
#include <sentencepiece_trainer.h>

sentencepiece::SentencePieceTrainer::Train("--input=test/botchan.txt --model_prefix=m --vocab_size=1000");
```

## ImmutableSentencePieceText
You will want to use the `ImmutableSentencePieceText` class to obtain the pieces and ids at the same time.
This proto also encodes the UTF-8 byte offset of each piece over the user input or detokenized text.

```C++
#include <sentencepiece_processor.h>

sentencepiece::ImmutableSentencePieceText spt;

// Encode
processor.Encode("This is a test.", spt.mutable_proto());

// or
// spt = processor.EncodeAsImmutableProto("This is a test.");

std::cout << spt.text() << std::endl;  // This is the same as the input.
for (const auto &piece : spt.pieces()) {
  std::cout << piece.begin() << std::endl;    // beginning of byte offset
  std::cout << piece.end() << std::endl;      // end of byte offset
  std::cout << piece.piece() << std::endl;    // internal representation.
  std::cout << piece.surface() << std::endl;  // external representation. spt.text().substr(begin, end - begin) == surface().
  std::cout << piece.id() << std::endl;       // vocab id
}

// Decode
processor.Decode({10, 20, 30}, spt.mutable_proto());
std::cout << spt.text() << std::endl;  // This is the same as the decoded string.
for (const auto &piece : spt.pieces()) {
  // the same as above.
}
```
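The Python wrapper exposes the same information through `out_type='immutable_proto'`; a minimal sketch, mirroring the Python examples later in this document:

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file='m.model')

# Encode into a proto that carries pieces, ids, and byte offsets together.
proto = sp.encode('This is a test', out_type='immutable_proto')
for n in proto.pieces:
    print(n.piece, n.surface, n.id, n.begin, n.end)
```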

## Vocabulary management
You will want to use the following methods to convert between ids and pieces.

```C++
processor.GetPieceSize();    // returns the size of the vocabulary.
processor.PieceToId("foo");  // returns the vocab id of "foo"
processor.IdToPiece(10);     // returns the string representation of id 10.
processor.IsUnknown(0);      // returns true if the given id is an unknown token. e.g., <unk>
processor.IsControl(10);     // returns true if the given id is a control token. e.g., <s>, </s>
```

## Extra Options
Use the `SetEncodeExtraOptions` and `SetDecodeExtraOptions` methods to set extra options for encoding and decoding respectively. These methods need to be called just after the `Load` method.

```C++
processor.SetEncodeExtraOptions("bos:eos");          // add <s> and </s>.
processor.SetEncodeExtraOptions("reverse:bos:eos");  // reverse the input and then add <s> and </s>.

processor.SetDecodeExtraOptions("reverse");          // the decoder's output is reversed.
```
@@ -1,145 +0,0 @@
# SentencePiece Experiments

## Experiments 1 (subword vs word-based model)
### Experimental settings

* Segmentation algorithms:
  * **SentencePiece**: SentencePiece with language-model based segmentation. (`--model_type=unigram`)
  * **SentencePiece(BPE)**: SentencePiece with Byte Pair Encoding. [[Sennrich et al.](http://www.aclweb.org/anthology/P16-1162)] (`--model_type=bpe`)
  * **Moses**: [Moses tokenizer](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl) for English.
  * **KyTea**: [KyTea](http://www.phontron.com/kytea/) for Japanese.
  * **MeCab**: [MeCab](http://taku910.github.io/mecab/) for Japanese.
  * **neologd**: [MeCab with neologd](https://github.com/neologd/mecab-ipadic-neologd) for Japanese.
  * **(Moses/KyTea)+SentencePiece**: Apply SentencePiece (Unigram) to pre-tokenized sentences. There are several variants with different tokenizers, e.g., **(Moses/MeCab)+SentencePiece** and **(MeCab/Moses)+SentencePiece**.
  * **char**: Segments sentences by characters.

* Data sets:
  * [KFTT](http://www.phontron.com/kftt/index.html)

* NMT parameters: ([Google's Neural Machine Translation System](https://arxiv.org/pdf/1609.08144.pdf) is applied for all experiments.)
  * Dropout prob: 0.2
  * num nodes: 512
  * num lstms: 6
  * Decoder parameters (α and β) are optimized with development data.

* Evaluation metrics:
  * Case-sensitive BLEU on detokenized text with the NIST scorer and KyTea segmenter. An in-house rule-based detokenizer was used for Moses/KyTea/MeCab/neologd.

### Results (BLEU scores)
#### English to Japanese
|Setting|vocab size|BLEU(dev)|BLEU(test)|src #tokens/sent.|trg #tokens/sent.|
|:---|---:|---:|---:|---:|---:|
|SentencePiece|4k (shared)|0.2857|0.2940|43.7478|29.6998|
|SentencePiece|8k (shared)|0.2785|0.2955|30.9734|25.0540|
|SentencePiece|16k (shared)|0.2664|0.2862|27.1827|21.5326|
|SentencePiece|32k (shared)|0.2641|0.2849|25.0592|19.0840|
|SentencePiece(BPE)|8k (shared)|0.2767|0.2947|31.7693|25.4331|
|(Moses/KyTea)+SentencePiece|8k (shared)|0.2900|0.2985|31.2719|29.9854|
|(Moses/MeCab)+SentencePiece|8k (shared)|0.2817|0.2950|31.4743|28.9537|
|(Moses/neologd)+SentencePiece|8k (shared)|0.2824|**0.3062**|31.2985|28.8645|
|Moses/KyTea|80k/80k|0.2576|0.2824|21.2513|23.2161|
|Moses/MeCab|80k/80k|0.2455|0.2780|21.2513|21.2033|
|Moses/neologd|80k/80k|0.2157|0.2378|21.2513|18.4768|
|Moses/SentencePiece|80k/8k|0.2475|0.2742|21.2513|22.9383|
|SentencePiece/KyTea|8k/80k|0.2778|0.2918|27.0429|23.2161|
|SentencePiece/MeCab|8k/80k|0.2673|0.2919|27.0429|21.2033|
|SentencePiece/neologd|8k/80k|0.2280|0.2494|27.0429|18.4768|
|Char|3k (shared)|0.2509|0.2679|109.8662|33.6963|

#### Japanese to English
|Setting|vocab size|BLEU(dev)|BLEU(test)|src #tokens/sent.|trg #tokens/sent.|
|:---|---:|---:|---:|---:|---:|
|SentencePiece|4k (shared)|0.1970|**0.2179**|29.6998|43.7478|
|SentencePiece|8k (shared)|0.1966|0.2162|25.0540|30.9734|
|SentencePiece|16k (shared)|0.1996|0.2160|21.5326|27.1827|
|SentencePiece|32k (shared)|0.1949|0.2159|19.0840|25.0592|
|SentencePiece(BPE)|8k (shared)|0.1977|0.2173|25.4331|31.7693|
|(KyTea/Moses)+SentencePiece|8k (shared)|0.1921|0.2086|29.9854|31.2719|
|(MeCab/Moses)+SentencePiece|8k (shared)|0.1909|0.2049|28.9537|31.4743|
|(neologd/Moses)+SentencePiece|8k (shared)|0.1938|0.2137|28.8645|31.2985|
|KyTea/Moses|80k/80k|0.1707|0.2006|23.2161|21.2513|
|MeCab/Moses|80k/80k|0.1668|0.1892|21.2033|21.2513|
|neologd/Moses|80k/80k|0.1589|0.1836|18.4768|21.2513|
|SentencePiece/Moses|8k/80k|0.1727|0.1994|22.9383|21.2513|
|KyTea/SentencePiece|80k/8k|0.1939|0.2141|23.2161|27.0429|
|MeCab/SentencePiece|80k/8k|0.1892|0.2077|21.2033|27.0429|
|neologd/SentencePiece|80k/8k|0.1641|0.1804|18.4768|27.0429|
|Char|3k (shared)|0.0824|0.0918|33.6963|109.8662|

#### Discussion
* **SentencePiece (Unigram/BPE)** outperforms word-based methods **(Moses/KyTea/MeCab/neologd)** even with a smaller vocabulary (10% of the word-based methods).
* The number of tokens needed to represent Japanese sentences is almost comparable between **SentencePiece (unigram)** and **KyTea**, though the vocabulary of **SentencePiece** is much smaller. This implies that SentencePiece can effectively compress sentences with a smaller vocabulary set.
* Pretokenization can slightly improve BLEU scores in English to Japanese. In Japanese to English translation, pretokenization doesn't help to improve BLEU.
* **neologd** shows a poor BLEU score. Tokenizing sentences with a large named-entity dictionary might not be effective in neural-based text processing.
* **SentencePiece (Unigram)** shows a slightly better text compression ratio than **BPE**, but there are no significant differences in BLEU score.
* The selection of vocabulary size for SentencePiece is sensitive in English to Japanese. This is probably because the vocabulary size drastically affects the tokenization results in Japanese, which has no explicit spaces between words.

## Experiments 2 (subwording with various pre-tokenizations)
### Experimental settings
We have evaluated SentencePiece segmentation with the following configurations.

* Segmentation algorithms:
  * **BPE** (Byte Pair Encoding) [[Sennrich et al.](http://www.aclweb.org/anthology/P16-1162)] (`--model_type=bpe`)
  * **Unigram**: Language-model based segmentation. (`--model_type=unigram`)

* Pretokenization methods:
  * **NoPretok**: No pretokenization. We train SentencePiece directly from raw sentences (`--split_by_whitespace=false`).
  * **WsPretok**: Trains the SentencePiece model from sentences tokenized by whitespaces (`--split_by_whitespace=true`). When handling CJK, this setting is almost equivalent to **NoPretok**.
  * **MosesPretok**: Trains the SentencePiece model from sentences tokenized by the [Moses tokenizer](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/tokenizer.perl). We used [KyTea](http://www.phontron.com/kytea/) for Japanese and in-house segmenters for Korean and Chinese respectively.

* NMT parameters: ([Google's Neural Machine Translation System](https://arxiv.org/pdf/1609.08144.pdf) is applied for all experiments.)
  * 16k shared vocabulary (shares the same vocabulary for source and target; we train a single SentencePiece model by concatenating raw source and target sentences.)
  * Dropout prob: 0.2
  * num nodes: 512
  * num lstms: 8

* Evaluation metrics:
  * Case-sensitive BLEU on detokenized text with the NIST scorer.
  * For CJK, the same word segmenters are applied prior to the NIST scorer.
  * No detokenizer is applied for **NoPretok** and **WsPretok**, which can directly emit detokenized sentences.
  * Applied the [Moses detokenizer](https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/detokenizer.perl) and an in-house rule-based detokenizer (CJK) for **MosesPretok**.

* Data sets:
  * [KFTT](http://www.phontron.com/kftt/index.html)
  * [MultiUN](http://opus.lingfil.uu.se/MultiUN.php) (The first 5M and next 5k/5k sentences are used for training and development/testing respectively.)
  * [WMT16](https://www.statmt.org/wmt16/)
  * In-house: (Used 5M parallel sentences for training)

**NoPretok** and **WsPretok** do not use any language-dependent resources.
**BPE**+**MosesPretok** is almost the same configuration used in [[Sennrich et al.](http://www.aclweb.org/anthology/P16-1162)] and [[Wu et al.](https://arxiv.org/pdf/1609.08144.pdf)].

### Results (BLEU scores)
|Language Pair|BPE(NoPretok)|BPE(WsPretok)|BPE(MosesPretok)|Unigram(NoPretok)|Unigram(WsPretok)|Unigram(MosesPretok)|
|---|---|---|---|---|---|---|
|KFTT en-ja| 0.2796| 0.281| 0.286| 0.2806| 0.280| 0.2871|
|KFTT ja-en| 0.1943| 0.208| 0.1967| 0.1985| 0.2148| 0.198|
|MultiUN ar-en| 0.5268| 0.5414| 0.5381| 0.5317| 0.5449| 0.5401|
|MultiUN en-ar| 0.4039| 0.4147| 0.4012| 0.4084| 0.4172| 0.3991|
|MultiUN en-zh| 0.4155| 0.4186| 0.395| 0.4214| 0.4165| 0.399|
|MultiUN zh-en| 0.46| 0.4716| 0.4806| 0.4644| 0.4711| 0.4759|
|In house en-ko| 0.178| 0.1851| 0.1893| 0.1846| 0.1872| 0.1890|
|In house ko-en| 0.1786| 0.1954| 0.1994| 0.1845| 0.1956| 0.2015|
|WMT16 cs-en| 0.1987| 0.2252| 0.2231| 0.2164| 0.2228| 0.2238|
|WMT16 de-en| 0.3194| 0.3348| 0.3374| 0.3261| 0.3375| 0.3398|
|WMT16 en-cs| 0.1607| 0.1827| 0.1812| 0.1722| 0.1778| 0.179|
|WMT16 en-de| 0.2847| 0.3029| 0.3013| 0.2946| 0.3000| 0.3053|
|WMT16 en-fi| 0.1434| 0.1528| 0.1499| 0.1472| 0.1568| 0.1517|
|WMT16 en-ru| 0.1884| 0.1973| 0.1989| 0.19| 0.1982| 0.1903|
|WMT16 fi-en| 0.1775| 0.1867| 0.1877| 0.182| 0.1882| 0.1865|
|WMT16 ru-en| 0.2042| 0.2229| 0.2194| 0.2087| 0.2201| 0.2155|

* **MosesPretok** does not always improve BLEU scores. Comparable accuracy can be obtained without using language-dependent resources in many language pairs.
* Whitespace pretokenization is a reasonable choice. It does not use language-specific resources.
* **NoPretok** shows poor BLEU scores. Unigrams are more robust than BPE when no pretokenizer is applied.

@@ -1,50 +0,0 @@
# Use custom normalization rule
By default, SentencePiece normalizes the input sentence with a variant of Unicode [NFKC](https://en.wikipedia.org/wiki/Unicode_equivalence).

SentencePiece allows us to define a custom normalization rule, which is stored in the model file.

## Use pre-defined normalization rule
SentencePiece provides the following pre-defined normalization rules. It is recommended to use one of them unless you have special reasons.

* **nmt_nfkc**: [NFKC](https://en.wikipedia.org/wiki/Unicode_equivalence) normalization with some additional normalization around spaces. (default)
* **nfkc**: original NFKC normalization.
* **nmt_nfkc_cf**: nmt_nfkc + [Unicode case folding](https://www.w3.org/International/wiki/Case_folding) (mostly lower casing)
* **nfkc_cf**: nfkc + [Unicode case folding](https://www.w3.org/International/wiki/Case_folding).
* **identity**: no normalization

You can choose the normalization rule with the `--normalization_rule_name` flag.
```
% spm_train --normalization_rule_name=identity --input=<input> --model_prefix=<model file> --vocab_size=8000
```

NOTE: Due to limitations of the normalization algorithm, full NFKC normalization is not implemented. [builder.h] describes example character sequences not normalized by our NFKC implementation.

The difference between **nmt_nfkc** and **nfkc** can be found via the ```diff -u data/nfkc.tsv data/nmt_nfkc.tsv``` command.

## Use custom normalization rule
The normalization is performed with user-defined string-to-string mappings and leftmost longest matching.

You can use a custom normalization rule by preparing a TSV file formatted as follows:
```
41 302 300	1EA6
41 302 301	1EA4
41 302 303	1EAA
...
```
In this sample, the UCS4 sequence [41 302 300] (hex) is converted into [1EA6] (hex). When there are ambiguities in the conversions, the longest rule is applied.
Note that a tab is used as the delimiter between the source and target sequences, and a space is used as the delimiter between UCS4 characters. The target sequence can be made empty to remove specific characters from the text.
See [data/nfkc.tsv](../data/nfkc.tsv) as an example. Once a TSV file is prepared, you can specify it with the `--normalization_rule_tsv` flag.
```
% spm_train --normalization_rule_tsv=<rule tsv file> --input=<input> --model_prefix=<model file> --vocab_size=8000
```

`<model file>` embeds the normalization rule, so the same normalization rule is applied when `<model file>` is used.

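A rule file in this format can also be generated programmatically; a minimal sketch that writes one mapping (a hypothetical example: fullwidth "Ａ", U+FF21, normalized to ASCII "A", U+0041):

```python
# Each rule: space-separated source codepoints (hex), a tab, then target codepoints.
rules = [
    ([0xFF21], [0x41]),  # fullwidth A -> A (assumed example mapping)
]

with open('custom.tsv', 'w', encoding='utf-8') as f:
    for src, tgt in rules:
        src_hex = ' '.join(format(cp, 'X') for cp in src)
        tgt_hex = ' '.join(format(cp, 'X') for cp in tgt)
        f.write(f'{src_hex}\t{tgt_hex}\n')
```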
## Command line tool to perform normalization
```
% spm_normalize --model=<model_file> file1 file2..
% spm_normalize --normalization_rule_tsv=custom.tsv file1 file2..
```
The first command line uses the normalization rule embedded in the model file. The second uses the normalization rule in the TSV file, and is useful for developing normalization rules interactively.

@@ -1,62 +0,0 @@
# Training options

The training options for `spm_train` can be listed using `spm_train --help`. Since the standard `pip install` of sentencepiece does not necessarily install `spm_train`, the options are also listed here.

```
Usage: ../build/src/spm_train [options] files

--input (comma separated list of input sentences) type: std::string default: ""
--input_format (Input format. Supported format is `text` or `tsv`.) type: std::string default: ""
--model_prefix (output model prefix) type: std::string default: ""
--model_type (model algorithm: unigram, bpe, word or char) type: std::string default: "unigram"
--vocab_size (vocabulary size) type: int32 default: 8000
--accept_language (comma-separated list of languages this model can accept) type: std::string default: ""
--self_test_sample_size (the size of self test samples) type: int32 default: 0
--character_coverage (character coverage to determine the minimum symbols) type: double default: 0.9995
--input_sentence_size (maximum size of sentences the trainer loads) type: std::uint64_t default: 0
--shuffle_input_sentence (Randomly sample input sentences in advance. Valid when --input_sentence_size > 0) type: bool default: true
--seed_sentencepiece_size (the size of seed sentencepieces) type: int32 default: 1000000
--shrinking_factor (Keeps top shrinking_factor pieces with respect to the loss) type: double default: 0.75
--num_threads (number of threads for training) type: int32 default: 16
--num_sub_iterations (number of EM sub-iterations) type: int32 default: 2
--max_sentencepiece_length (maximum length of sentence piece) type: int32 default: 16
--max_sentence_length (maximum length of sentence in byte) type: int32 default: 4192
--split_by_unicode_script (use Unicode script to split sentence pieces) type: bool default: true
--split_by_number (split tokens by numbers (0-9)) type: bool default: true
--split_by_whitespace (use a white space to split sentence pieces) type: bool default: true
--split_digits (split all digits (0-9) into separate pieces) type: bool default: false
--treat_whitespace_as_suffix (treat whitespace marker as suffix instead of prefix.) type: bool default: false
--allow_whitespace_only_pieces (allow pieces that only contain (consecutive) whitespace tokens) type: bool default: false
--control_symbols (comma separated list of control symbols) type: std::string default: ""
--control_symbols_file (load control_symbols from file.) type: std::string default: ""
--user_defined_symbols (comma separated list of user defined symbols) type: std::string default: ""
--user_defined_symbols_file (load user_defined_symbols from file.) type: std::string default: ""
--required_chars (UTF8 characters in this flag are always used in the character set regardless of --character_coverage) type: std::string default: ""
--required_chars_file (load required_chars from file.) type: std::string default: ""
--byte_fallback (decompose unknown pieces into UTF-8 byte pieces) type: bool default: false
--vocabulary_output_piece_score (Define score in vocab file) type: bool default: true
--normalization_rule_name (Normalization rule name. Choose from nfkc or identity) type: std::string default: "nmt_nfkc"
--normalization_rule_tsv (Normalization rule TSV file.) type: std::string default: ""
--denormalization_rule_tsv (Denormalization rule TSV file.) type: std::string default: ""
--add_dummy_prefix (Add dummy whitespace at the beginning of text) type: bool default: true
--remove_extra_whitespaces (Removes leading, trailing, and duplicate internal whitespace) type: bool default: true
--hard_vocab_limit (If set to false, --vocab_size is considered as a soft limit.) type: bool default: true
--use_all_vocab (If set to true, use all tokens as vocab. Valid for word/char models.) type: bool default: false
--unk_id (Override UNK (<unk>) id.) type: int32 default: 0
--bos_id (Override BOS (<s>) id. Set -1 to disable BOS.) type: int32 default: 1
--eos_id (Override EOS (</s>) id. Set -1 to disable EOS.) type: int32 default: 2
--pad_id (Override PAD (<pad>) id. Set -1 to disable PAD.) type: int32 default: -1
--unk_piece (Override UNK (<unk>) piece.) type: std::string default: "<unk>"
--bos_piece (Override BOS (<s>) piece.) type: std::string default: "<s>"
--eos_piece (Override EOS (</s>) piece.) type: std::string default: "</s>"
--pad_piece (Override PAD (<pad>) piece.) type: std::string default: "<pad>"
--unk_surface (Dummy surface string for <unk>. In decoding <unk> is decoded to `unk_surface`.) type: std::string default: " ⁇ "
--train_extremely_large_corpus (Increase bit depth for unigram tokenization.) type: bool default: false
--random_seed (Seed value for random generator.) type: uint32 default: 4294967295
--enable_differential_privacy (Whether to add DP while training. Currently supported only by UNIGRAM model.) type: bool default: false
--differential_privacy_noise_level (Amount of noise to add for DP) type: float default: 0
--differential_privacy_clipping_threshold (Threshold for clipping the counts for DP) type: std::uint64_t default: 0
--help (show help) type: bool default: false
--version (show version) type: bool default: false
--minloglevel (Messages logged at a lower level than this don't actually get logged anywhere) type: int default: 0
```

@@ -1,19 +0,0 @@
# Use custom symbols
The SentencePiece model supports two types of special symbols.

## Control symbol
Control symbols are used to encode special indicators for the decoder to change its behavior dynamically.
Examples include the language indicators in multilingual models. `<s>` and `</s>` are reserved control symbols.
Control symbols must be inserted outside of the SentencePiece segmentation. Developers are responsible for inserting these symbols during data generation and decoding.

It is guaranteed that control symbols have no corresponding surface strings in the original user input. Control symbols are decoded into empty strings.

## User defined symbol
A user-defined symbol is handled as one piece in any context. If the symbol is included in the input text, it is always extracted as one piece.

## Specify special symbols at training time
Use the `--control_symbols` and `--user_defined_symbols` flags as follows:

```
% spm_train --control_symbols=<foo>,<bar> --user_defined_symbols=<user1>,<user2> --input=<input file> --model_prefix=<model file> --vocab_size=8000
```
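The difference between the two symbol types can be observed from Python; a minimal sketch, assuming a model trained with the flags above (the model path is a placeholder):

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file='<model file>.model')  # placeholder path

foo_id = sp.piece_to_id('<foo>')      # control symbol
user1_id = sp.piece_to_id('<user1>')  # user-defined symbol

print(sp.is_control(foo_id))    # expected: True
print(sp.is_control(user1_id))  # expected: False

# A user-defined symbol in the input is always kept as one piece.
print(sp.encode('a <user1> b', out_type=str))
```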
@@ -1,3 +0,0 @@
/*.so
/build
/*.pickle

@@ -1,4 +0,0 @@
recursive-include test *.py *.model botchan.txt
recursive-include src *.i
recursive-include sentencepiece *
include *.md VERSION.* build_bundled.sh
@@ -1,183 +0,0 @@
# SentencePiece Python Wrapper

Python wrapper for SentencePiece. This API offers the encoding, decoding and training of SentencePiece.

## Build and Install SentencePiece
For Linux (x64/i686), macOS, and Windows (win32/x64) environments, you can simply use the pip command to install the SentencePiece Python module.

```
% pip install sentencepiece
```

To build and install the Python wrapper from source, try the following commands to build and install the wheel package.
```
% git clone https://github.com/google/sentencepiece.git
% cd sentencepiece
% mkdir build
% cd build
% cmake .. -DSPM_ENABLE_SHARED=OFF -DCMAKE_INSTALL_PREFIX=./root
% make install
% cd ../python
% python setup.py bdist_wheel
% pip install dist/sentencepiece*.whl
```

If you don't have write permission to the global site-packages directory or don't want to install into it, please try:
```
% python setup.py install --user
```

## Usage

See [this google colab page](https://github.com/google/sentencepiece/blob/master/python/sentencepiece_python_module_example.ipynb) to run sentencepiece interactively.

### Segmentation
```
% python
>>> import sentencepiece as spm
>>> sp = spm.SentencePieceProcessor(model_file='test/test_model.model')

>>> sp.encode('This is a test')
[284, 47, 11, 4, 15, 400]

>>> sp.encode(['This is a test', 'Hello world'], out_type=int)
[[284, 47, 11, 4, 15, 400], [151, 88, 21, 887]]

>>> sp.encode_as_ids(['This is a test', 'Hello world'])
[[284, 47, 11, 4, 15, 400], [151, 88, 21, 887]]

>>> sp.encode('This is a test', out_type=str)
['▁This', '▁is', '▁a', '▁', 't', 'est']

>>> sp.encode(['This is a test', 'Hello world'], out_type=str)
[['▁This', '▁is', '▁a', '▁', 't', 'est'], ['▁He', 'll', 'o', '▁world']]

>>> sp.encode_as_pieces(['This is a test', 'Hello world'])
[['▁This', '▁is', '▁a', '▁', 't', 'est'], ['▁He', 'll', 'o', '▁world']]

>>> proto = sp.encode('This is a test', out_type='immutable_proto')
>>> for n in proto.pieces:
...     print('piece="{}" surface="{}" id={} begin={} end={}'.format(n.piece, n.surface, n.id, n.begin, n.end))
...
piece="▁This" surface="This" id=284 begin=0 end=4
piece="▁is" surface=" is" id=47 begin=4 end=7
piece="▁a" surface=" a" id=11 begin=7 end=9
piece="▁" surface=" " id=4 begin=9 end=10
piece="t" surface="t" id=15 begin=10 end=11
piece="est" surface="est" id=400 begin=11 end=14

>>> [[x.id for x in proto.pieces], [x.piece for x in proto.pieces], [x.begin for x in proto.pieces], [x.end for x in proto.pieces]]
[[284, 47, 11, 4, 15, 400], ['▁This', '▁is', '▁a', '▁', 't', 'est'], [0, 4, 7, 9, 10, 11], [4, 7, 9, 10, 11, 14]]

>>> proto2 = sp.encode_as_immutable_proto('This is a test')
>>> proto2 == proto
True

>>> for _ in range(10):
...     sp.encode('This is a test', out_type=str, enable_sampling=True, alpha=0.1, nbest_size=-1)
...
['▁', 'This', '▁', 'is', '▁a', '▁', 't', 'e', 'st']
['▁T', 'h', 'i', 's', '▁is', '▁a', '▁', 'te', 's', 't']
['▁T', 'h', 'is', '▁', 'is', '▁', 'a', '▁', 't', 'est']
['▁', 'This', '▁is', '▁', 'a', '▁', 't', 'e', 'st']
['▁', 'This', '▁', 'is', '▁', 'a', '▁', 't', 'e', 's', 't']
['▁This', '▁is', '▁a', '▁', 'te', 's', 't']
['▁This', '▁is', '▁', 'a', '▁', 't', 'e', 'st']
['▁', 'T', 'h', 'is', '▁', 'is', '▁', 'a', '▁', 'te', 'st']
['▁', 'This', '▁', 'i', 's', '▁a', '▁', 't', 'e', 'st']
['▁This', '▁', 'is', '▁a', '▁', 't', 'est']

>>> sp.nbest_encode('This is a test', nbest_size=5, out_type=str)
[['▁This', '▁is', '▁a', '▁', 't', 'est'],
['▁This', '▁is', '▁a', '▁', 'te', 'st'],
['▁This', '▁is', '▁a', '▁', 'te', 's', 't'],
['▁This', '▁is', '▁a', '▁', 't', 'e', 'st'],
['▁This', '▁is', '▁a', '▁', 't', 'es', 't']]

>>> sp.sample_encode_and_score('This is a test', num_samples=5, alpha=0.1, out_type=str, wor=True)
[(['▁This', '▁', 'i', 's', '▁a', '▁', 'te', 's', 't'], -3.043105125427246),
(['▁This', '▁', 'i', 's', '▁a', '▁', 'te', 'st'], -2.8475849628448486),
(['▁', 'This', '▁is', '▁', 'a', '▁', 'te', 'st'], -3.043248176574707),
(['▁', 'This', '▁is', '▁a', '▁', 't', 'e', 'st'], -2.87727689743042),
(['▁', 'This', '▁', 'i', 's', '▁', 'a', '▁', 't', 'est'], -3.6284031867980957)]

>>> sp.decode([284, 47, 11, 4, 15, 400])
'This is a test'

>>> sp.decode([[284, 47, 11, 4, 15, 400], [151, 88, 21, 887]])
['This is a test', 'Hello world']

>>> proto = sp.decode([284, 47, 11, 4, 15, 400], out_type='immutable_proto')
>>> proto.text
'This is a test'

>>> sp.decode(['▁', 'This', '▁', 'is', '▁a', '▁', 't', 'e', 'st'])
'This is a test'

>>> sp.decode([['▁This', '▁is', '▁a', '▁', 't', 'est'], ['▁He', 'll', 'o', '▁world']])
['This is a test', 'Hello world']

>>> sp.get_piece_size()
1000

>>> sp.id_to_piece(2)
'</s>'

>>> sp.id_to_piece([2, 3, 4])
['</s>', '\r', '▁']

>>> sp.piece_to_id('<s>')
1

>>> sp.piece_to_id(['</s>', '\r', '▁'])
[2, 3, 4]

>>> len(sp)
1000

>>> sp['</s>']
2
```

### Model Training
Training is performed by passing the parameters of [spm_train](https://github.com/google/sentencepiece#train-sentencepiece-model) to the SentencePieceTrainer.train() function.

```
>>> import sentencepiece as spm
>>> spm.SentencePieceTrainer.train(input='test/botchan.txt', model_prefix='m', vocab_size=1000, user_defined_symbols=['foo', 'bar'])
sentencepiece_trainer.cc(73) LOG(INFO) Starts training with :
trainer_spec {
input: test/botchan.txt
.. snip
unigram_model_trainer.cc(500) LOG(INFO) EM sub_iter=1 size=1188 obj=10.2839 num_tokens=32182 num_tokens/piece=27.0892
unigram_model_trainer.cc(500) LOG(INFO) EM sub_iter=0 size=1100 obj=10.4269 num_tokens=33001 num_tokens/piece=30.0009
unigram_model_trainer.cc(500) LOG(INFO) EM sub_iter=1 size=1100 obj=10.4069 num_tokens=33002 num_tokens/piece=30.0018
trainer_interface.cc(595) LOG(INFO) Saving model: m.model
trainer_interface.cc(619) LOG(INFO) Saving vocabs: m.vocab
>>>
```

### Training without local filesystem
The SentencePiece trainer can receive any iterable object to feed training sentences. You can also pass a file object (an instance with a write() method) to emit the output model to any device. These features are useful for running sentencepiece in environments that have limited access to the local file system (e.g., Google Colab).

```
import urllib.request
import io
import sentencepiece as spm

# Loads model from URL as iterator and stores the model to BytesIO.
model = io.BytesIO()
with urllib.request.urlopen(
    'https://raw.githubusercontent.com/google/sentencepiece/master/data/botchan.txt'
) as response:
  spm.SentencePieceTrainer.train(
      sentence_iterator=response, model_writer=model, vocab_size=1000)

# Serialize the model as file.
# with open('out.model', 'wb') as f:
#   f.write(model.getvalue())

# Directly load the model from serialized model.
sp = spm.SentencePieceProcessor(model_proto=model.getvalue())
print(sp.encode('this is test'))
```
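The sentence iterator does not have to be a file or network stream; a minimal sketch, feeding an in-memory list of sentences and keeping the model in memory as well:

```python
import io
import sentencepiece as spm

sentences = ['Hello world.', 'This is a test.', 'I saw a girl with a telescope.']

model = io.BytesIO()
spm.SentencePieceTrainer.train(
    sentence_iterator=iter(sentences),  # any iterable of str works
    model_writer=model,
    vocab_size=30)  # tiny vocabulary for this toy corpus; raise it for real data

sp = spm.SentencePieceProcessor(model_proto=model.getvalue())
print(sp.encode('Hello world.', out_type=str))
```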
@@ -1,138 +0,0 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### You can add new special tokens to a pre-trained sentencepiece model\n",
    "#### Run this code in google/sentencepiece/python/src/sentencepiece"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load pre-trained sentencepiece model\n",
    "A pre-trained model is needed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "371391"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import sentencepiece_model_pb2 as model\n",
    "m = model.ModelProto()\n",
    "m.ParseFromString(open(\"old.model\", \"rb\").read())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the tokens you want to add\n",
    "Prepare the list of new tokens you want to add"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['[UNK]',\n",
       " '[PAD]',\n",
       " '[CLS]',\n",
       " '[SEP]',\n",
       " '[MASK]',\n",
       " '[EOS]',\n",
       " '[DOMAIN]',\n",
       " '[SLOT]',\n",
       " '[ACTION]']"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "special_tokens = open(\"special_tokens.txt\", \"r\").read().split(\"\\n\")\n",
    "special_tokens"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Add new tokens to the sentencepiece model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "for token in special_tokens:\n",
    "    new_token = model.ModelProto().SentencePiece()\n",
    "    new_token.piece = token\n",
    "    new_token.score = 0\n",
    "    m.pieces.append(new_token)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save the new sentencepiece model\n",
    "Load the new sentencepiece model into your NLP system"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('new.model', 'wb') as f:\n",
    "    f.write(m.SerializeToString())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
@@ -1,22 +0,0 @@
#!/bin/sh

VERSION="$1"

mkdir -p build

BUILD_DIR=./build
INSTALL_DIR=./build/root

if [ -f ./sentencepiece/src/CMakeLists.txt ]; then
  SRC_DIR=./sentencepiece
elif [ -f ../src/CMakeLists.txt ]; then
  SRC_DIR=..
else
  # Try tagged version. Otherwise, use head.
  git clone https://github.com/google/sentencepiece.git -b v"${VERSION}" --depth 1 || \
    git clone https://github.com/google/sentencepiece.git --depth 1
  SRC_DIR=./sentencepiece
fi

cmake ${SRC_DIR} -B ${BUILD_DIR} -DSPM_ENABLE_SHARED=OFF -DCMAKE_INSTALL_PREFIX=${INSTALL_DIR}
cmake --build ${BUILD_DIR} --config Release --target install --parallel $(nproc)
@@ -1,11 +0,0 @@
#!/bin/sh

mkdir -p sentencepiece

for i in CMakeLists.txt LICENSE README.md VERSION.txt cmake config.h.in sentencepiece.pc.in src third_party
do
  echo "copying ../${i} sentencepiece/${i}"
  cp -f -R "../${i}" sentencepiece
done

python3 setup.py sdist
File diff suppressed because it is too large
@@ -1,2 +0,0 @@
[metadata]
description_file = README.md
@@ -1,201 +0,0 @@
#!/usr/bin/env python

# Copyright 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import os
import string
import subprocess
import sys
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext as _build_ext
from setuptools.command.build_py import build_py as _build_py

sys.path.append(os.path.join('.', 'test'))


def long_description():
  with codecs.open('README.md', 'r', 'utf-8') as f:
    long_description = f.read()
  return long_description


exec(open('src/sentencepiece/_version.py').read())


def run_pkg_config(section, pkg_config_path=None):
  try:
    cmd = 'pkg-config sentencepiece --{}'.format(section)
    if pkg_config_path:
      cmd = 'env PKG_CONFIG_PATH={} {}'.format(pkg_config_path, cmd)
    output = subprocess.check_output(cmd, shell=True)
    if sys.version_info >= (3, 0, 0):
      output = output.decode('utf-8')
  except subprocess.CalledProcessError:
    sys.stderr.write('Failed to find sentencepiece pkg-config\n')
    sys.exit(1)
  return output.strip().split()


def is_sentencepiece_installed():
  try:
    subprocess.check_call('pkg-config sentencepiece --libs', shell=True)
    return True
  except subprocess.CalledProcessError:
    return False


def get_cflags_and_libs(root):
  cflags = ['-std=c++17', '-I' + os.path.join(root, 'include')]
  libs = []
  if os.path.exists(os.path.join(root, 'lib/pkgconfig/sentencepiece.pc')):
    libs = [
        os.path.join(root, 'lib/libsentencepiece.a'),
        os.path.join(root, 'lib/libsentencepiece_train.a'),
    ]
  elif os.path.exists(os.path.join(root, 'lib64/pkgconfig/sentencepiece.pc')):
    libs = [
        os.path.join(root, 'lib64/libsentencepiece.a'),
        os.path.join(root, 'lib64/libsentencepiece_train.a'),
    ]
  return cflags, libs


class build_ext(_build_ext):
  """Override build_extension to run cmake."""

  def build_extension(self, ext):
    cflags, libs = get_cflags_and_libs('../build/root')

    if len(libs) == 0:
      if is_sentencepiece_installed():
        cflags = cflags + run_pkg_config('cflags')
        libs = run_pkg_config('libs')
      else:
        subprocess.check_call(['./build_bundled.sh', __version__])
        cflags, libs = get_cflags_and_libs('./build/root')

    # Fix compile on some versions of Mac OSX
    # See: https://github.com/neulab/xnmt/issues/199
    if sys.platform == 'darwin':
      cflags.append('-mmacosx-version-min=10.9')
    else:
      cflags.append('-Wl,-strip-all')
      libs.append('-Wl,-strip-all')
    if sys.platform == 'linux':
      libs.append('-Wl,-Bsymbolic')
    print('## cflags={}'.format(' '.join(cflags)))
    print('## libs={}'.format(' '.join(libs)))
    ext.extra_compile_args = cflags
    ext.extra_link_args = libs
    _build_ext.build_extension(self, ext)


if os.name == 'nt':
  # Must pre-install sentencepiece into the build directory.
  arch = 'win32'
  if sys.maxsize > 2**32:
    arch = 'amd64'
  if os.path.exists('..\\build\\root_{}\\lib'.format(arch)):
    cflags = ['/std:c++17', '/I..\\build\\root_{}\\include'.format(arch)]
    libs = [
        '..\\build\\root_{}\\lib\\sentencepiece.lib'.format(arch),
        '..\\build\\root_{}\\lib\\sentencepiece_train.lib'.format(arch),
    ]
  elif os.path.exists('..\\build\\root\\lib'):
    cflags = ['/std:c++17', '/I..\\build\\root\\include']
    libs = [
        '..\\build\\root\\lib\\sentencepiece.lib',
        '..\\build\\root\\lib\\sentencepiece_train.lib',
    ]
  else:
    # Build the library locally with cmake and vc++.
    cmake_arch = 'Win32'
    if arch == 'amd64':
      cmake_arch = 'x64'
    subprocess.check_call([
        'cmake',
        'sentencepiece',
        '-A',
        cmake_arch,
        '-B',
        'build',
        '-DSPM_ENABLE_SHARED=OFF',
        '-DCMAKE_INSTALL_PREFIX=build\\root',
    ])
    subprocess.check_call([
        'cmake',
        '--build',
        'build',
        '--config',
        'Release',
        '--target',
        'install',
        '--parallel',
        '8',
    ])
    cflags = ['/std:c++17', '/I.\\build\\root\\include']
    libs = [
        '.\\build\\root\\lib\\sentencepiece.lib',
        '.\\build\\root\\lib\\sentencepiece_train.lib',
    ]

  SENTENCEPIECE_EXT = Extension(
      'sentencepiece._sentencepiece',
      sources=['src/sentencepiece/sentencepiece_wrap.cxx'],
      extra_compile_args=cflags,
      extra_link_args=libs,
  )
  cmdclass = {}
else:
  SENTENCEPIECE_EXT = Extension(
      'sentencepiece._sentencepiece',
      sources=['src/sentencepiece/sentencepiece_wrap.cxx'],
  )
  cmdclass = {'build_ext': build_ext}

setup(
    name='sentencepiece',
    author='Taku Kudo',
    author_email='taku@google.com',
    description='SentencePiece python wrapper',
    long_description=long_description(),
    long_description_content_type='text/markdown',
    version=__version__,
    package_dir={'': 'src'},
    url='https://github.com/google/sentencepiece',
    license='Apache',
    platforms='Unix',
    py_modules=[
        'sentencepiece/__init__',
        'sentencepiece/_version',
        'sentencepiece/sentencepiece_model_pb2',
        'sentencepiece/sentencepiece_pb2',
    ],
    ext_modules=[SENTENCEPIECE_EXT],
    cmdclass=cmdclass,
    classifiers=[
        'Development Status :: 5 - Production/Stable',
        'Environment :: Console',
        'Intended Audience :: Developers',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: Apache Software License',
        'Operating System :: Unix',
        'Programming Language :: Python',
        'Topic :: Text Processing :: Linguistic',
        'Topic :: Software Development :: Libraries :: Python Modules',
    ],
    test_suite='sentencepiece_test.suite',
)

File diff suppressed because it is too large
@@ -1 +0,0 @@
__version__ = '0.2.0'

File diff suppressed because it is too large
@@ -1,44 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: sentencepiece_model.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19sentencepiece_model.proto\x12\rsentencepiece\"\x80\x0c\n\x0bTrainerSpec\x12\r\n\x05input\x18\x01 \x03(\t\x12\x14\n\x0cinput_format\x18\x07 \x01(\t\x12\x14\n\x0cmodel_prefix\x18\x02 \x01(\t\x12\x41\n\nmodel_type\x18\x03 \x01(\x0e\x32$.sentencepiece.TrainerSpec.ModelType:\x07UNIGRAM\x12\x18\n\nvocab_size\x18\x04 \x01(\x05:\x04\x38\x30\x30\x30\x12\x17\n\x0f\x61\x63\x63\x65pt_language\x18\x05 \x03(\t\x12 \n\x15self_test_sample_size\x18\x06 \x01(\x05:\x01\x30\x12*\n\x1b\x65nable_differential_privacy\x18\x32 \x01(\x08:\x05\x66\x61lse\x12+\n differential_privacy_noise_level\x18\x33 \x01(\x02:\x01\x30\x12\x32\n\'differential_privacy_clipping_threshold\x18\x34 \x01(\x04:\x01\x30\x12\"\n\x12\x63haracter_coverage\x18\n \x01(\x02:\x06\x30.9995\x12\x1e\n\x13input_sentence_size\x18\x0b \x01(\x04:\x01\x30\x12$\n\x16shuffle_input_sentence\x18\x13 \x01(\x08:\x04true\x12 \n\x14mining_sentence_size\x18\x0c \x01(\x05\x42\x02\x18\x01\x12\"\n\x16training_sentence_size\x18\r \x01(\x05\x42\x02\x18\x01\x12(\n\x17seed_sentencepiece_size\x18\x0e \x01(\x05:\x07\x31\x30\x30\x30\x30\x30\x30\x12\x1e\n\x10shrinking_factor\x18\x0f \x01(\x02:\x04\x30.75\x12!\n\x13max_sentence_length\x18\x12 \x01(\x05:\x04\x34\x31\x39\x32\x12\x17\n\x0bnum_threads\x18\x10 \x01(\x05:\x02\x31\x36\x12\x1d\n\x12num_sub_iterations\x18\x11 \x01(\x05:\x01\x32\x12$\n\x18max_sentencepiece_length\x18\x14 \x01(\x05:\x02\x31\x36\x12%\n\x17split_by_unicode_script\x18\x15 \x01(\x08:\x04true\x12\x1d\n\x0fsplit_by_number\x18\x17 \x01(\x08:\x04true\x12!\n\x13split_by_whitespace\x18\x16 \x01(\x08:\x04true\x12)\n\x1atreat_whitespace_as_suffix\x18\x18 \x01(\x08:\x05\x66\x61lse\x12+\n\x1c\x61llow_whitespace_only_pieces\x18\x1a \x01(\x08:\x05\x66\x61lse\x12\x1b\n\x0csplit_digits\x18\x19 \x01(\x08:\x05\x66\x61lse\x12#\n\x19pretokenization_delimiter\x18\x35 \x01(\t:\x00\x12\x17\n\x0f\x63ontrol_symbols\x18\x1e \x03(\t\x12\x1c\n\x14user_defined_symbols\x18\x1f \x03(\t\x12\x16\n\x0erequired_chars\x18$ \x01(\t\x12\x1c\n\rbyte_fallback\x18# \x01(\x08:\x05\x66\x61lse\x12+\n\x1dvocabulary_output_piece_score\x18 \x01(\x08:\x04true\x12\x1e\n\x10hard_vocab_limit\x18! \x01(\x08:\x04true\x12\x1c\n\ruse_all_vocab\x18\" \x01(\x08:\x05\x66\x61lse\x12\x11\n\x06unk_id\x18( \x01(\x05:\x01\x30\x12\x11\n\x06\x62os_id\x18) \x01(\x05:\x01\x31\x12\x11\n\x06\x65os_id\x18* \x01(\x05:\x01\x32\x12\x12\n\x06pad_id\x18+ \x01(\x05:\x02-1\x12\x18\n\tunk_piece\x18- \x01(\t:\x05<unk>\x12\x16\n\tbos_piece\x18. \x01(\t:\x03<s>\x12\x17\n\teos_piece\x18/ \x01(\t:\x04</s>\x12\x18\n\tpad_piece\x18\x30 \x01(\t:\x05<pad>\x12\x1a\n\x0bunk_surface\x18, \x01(\t:\x05 \xe2\x81\x87 \x12+\n\x1ctrain_extremely_large_corpus\x18\x31 \x01(\x08:\x05\x66\x61lse\"5\n\tModelType\x12\x0b\n\x07UNIGRAM\x10\x01\x12\x07\n\x03\x42PE\x10\x02\x12\x08\n\x04WORD\x10\x03\x12\x08\n\x04\x43HAR\x10\x04*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xd1\x01\n\x0eNormalizerSpec\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1c\n\x14precompiled_charsmap\x18\x02 \x01(\x0c\x12\x1e\n\x10\x61\x64\x64_dummy_prefix\x18\x03 \x01(\x08:\x04true\x12&\n\x18remove_extra_whitespaces\x18\x04 \x01(\x08:\x04true\x12 \n\x12\x65scape_whitespaces\x18\x05 \x01(\x08:\x04true\x12\x1e\n\x16normalization_rule_tsv\x18\x06 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"y\n\x0cSelfTestData\x12\x33\n\x07samples\x18\x01 \x03(\x0b\x32\".sentencepiece.SelfTestData.Sample\x1a)\n\x06Sample\x12\r\n\x05input\x18\x01 \x01(\t\x12\x10\n\x08\x65xpected\x18\x02 \x01(\t*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"\xfe\x03\n\nModelProto\x12\x37\n\x06pieces\x18\x01 \x03(\x0b\x32\'.sentencepiece.ModelProto.SentencePiece\x12\x30\n\x0ctrainer_spec\x18\x02 \x01(\x0b\x32\x1a.sentencepiece.TrainerSpec\x12\x36\n\x0fnormalizer_spec\x18\x03 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x12\x33\n\x0eself_test_data\x18\x04 \x01(\x0b\x32\x1b.sentencepiece.SelfTestData\x12\x38\n\x11\x64\x65normalizer_spec\x18\x05 \x01(\x0b\x32\x1d.sentencepiece.NormalizerSpec\x1a\xd2\x01\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\r\n\x05score\x18\x02 \x01(\x02\x12\x42\n\x04type\x18\x03 \x01(\x0e\x32,.sentencepiece.ModelProto.SentencePiece.Type:\x06NORMAL\"T\n\x04Type\x12\n\n\x06NORMAL\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x0b\n\x07\x43ONTROL\x10\x03\x12\x10\n\x0cUSER_DEFINED\x10\x04\x12\x08\n\x04\x42YTE\x10\x06\x12\n\n\x06UNUSED\x10\x05*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\x42\x02H\x03')

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'sentencepiece_model_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:

  DESCRIPTOR._options = None
  DESCRIPTOR._serialized_options = b'H\003'
  _TRAINERSPEC.fields_by_name['mining_sentence_size']._options = None
  _TRAINERSPEC.fields_by_name['mining_sentence_size']._serialized_options = b'\030\001'
  _TRAINERSPEC.fields_by_name['training_sentence_size']._options = None
  _TRAINERSPEC.fields_by_name['training_sentence_size']._serialized_options = b'\030\001'
  _TRAINERSPEC._serialized_start=45
  _TRAINERSPEC._serialized_end=1581
  _TRAINERSPEC_MODELTYPE._serialized_start=1517
  _TRAINERSPEC_MODELTYPE._serialized_end=1570
  _NORMALIZERSPEC._serialized_start=1584
  _NORMALIZERSPEC._serialized_end=1793
  _SELFTESTDATA._serialized_start=1795
  _SELFTESTDATA._serialized_end=1916
  _SELFTESTDATA_SAMPLE._serialized_start=1864
  _SELFTESTDATA_SAMPLE._serialized_end=1905
  _MODELPROTO._serialized_start=1919
  _MODELPROTO._serialized_end=2429
  _MODELPROTO_SENTENCEPIECE._serialized_start=2208
  _MODELPROTO_SENTENCEPIECE._serialized_end=2418
  _MODELPROTO_SENTENCEPIECE_TYPE._serialized_start=2323
  _MODELPROTO_SENTENCEPIECE_TYPE._serialized_end=2407
# @@protoc_insertion_point(module_scope)
@@ -1,30 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: sentencepiece.proto
"""Generated protocol buffer code."""
from google.protobuf.internal import builder as _builder
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x13sentencepiece.proto\x12\rsentencepiece\"\xdf\x01\n\x11SentencePieceText\x12\x0c\n\x04text\x18\x01 \x01(\t\x12>\n\x06pieces\x18\x02 \x03(\x0b\x32..sentencepiece.SentencePieceText.SentencePiece\x12\r\n\x05score\x18\x03 \x01(\x02\x1a\x62\n\rSentencePiece\x12\r\n\x05piece\x18\x01 \x01(\t\x12\n\n\x02id\x18\x02 \x01(\r\x12\x0f\n\x07surface\x18\x03 \x01(\t\x12\r\n\x05\x62\x65gin\x18\x04 \x01(\r\x12\x0b\n\x03\x65nd\x18\x05 \x01(\r*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02*\t\x08\xc8\x01\x10\x80\x80\x80\x80\x02\"J\n\x16NBestSentencePieceText\x12\x30\n\x06nbests\x18\x01 \x03(\x0b\x32 .sentencepiece.SentencePieceTextB\x02H\x03')

_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'sentencepiece_pb2', globals())
if _descriptor._USE_C_DESCRIPTORS == False:

  DESCRIPTOR._options = None
  DESCRIPTOR._serialized_options = b'H\003'
  _SENTENCEPIECETEXT._serialized_start=39
  _SENTENCEPIECETEXT._serialized_end=262
  _SENTENCEPIECETEXT_SENTENCEPIECE._serialized_start=153
  _SENTENCEPIECETEXT_SENTENCEPIECE._serialized_end=251
  _NBESTSENTENCEPIECETEXT._serialized_start=264
  _NBESTSENTENCEPIECETEXT._serialized_end=338
# @@protoc_insertion_point(module_scope)

File diff suppressed because it is too large
@@ -1 +0,0 @@
../../data/botchan.txt
@@ -1,934 +0,0 @@
#!/usr/bin/python
# -*- coding: utf-8 -*-

# Copyright 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
import io
import os
import pickle
import sys
import unittest
import sentencepiece as spm

print('VERSION={}'.format(spm.__version__))

data_dir = 'test'
if sys.platform == 'win32':
  data_dir = os.path.join('..', 'data')


class TestSentencepieceProcessor(unittest.TestCase):
  """Test case for SentencePieceProcessor"""

  def setUp(self):
    self.sp_ = spm.SentencePieceProcessor()
    self.jasp_ = spm.SentencePieceProcessor()
    self.assertTrue(self.sp_.Load(os.path.join('test', 'test_model.model')))
    self.assertTrue(
        self.jasp_.Load(os.path.join('test', 'test_ja_model.model'))
    )
    with open(os.path.join('test', 'test_model.model'), 'rb') as f:
      self.assertTrue(self.sp_.LoadFromSerializedProto(f.read()))
    with open(os.path.join('test', 'test_ja_model.model'), 'rb') as f:
      self.assertTrue(self.jasp_.LoadFromSerializedProto(f.read()))

  def test_load(self):
    self.assertEqual(1000, self.sp_.GetPieceSize())
    self.assertEqual(0, self.sp_.PieceToId('<unk>'))
    self.assertEqual(1, self.sp_.PieceToId('<s>'))
    self.assertEqual(2, self.sp_.PieceToId('</s>'))
    self.assertEqual('<unk>', self.sp_.IdToPiece(0))
    self.assertEqual('<s>', self.sp_.IdToPiece(1))
    self.assertEqual('</s>', self.sp_.IdToPiece(2))
    self.assertEqual(0, self.sp_.unk_id())
    self.assertEqual(1, self.sp_.bos_id())
    self.assertEqual(2, self.sp_.eos_id())
    self.assertEqual(-1, self.sp_.pad_id())
    for i in range(self.sp_.GetPieceSize()):
      piece = self.sp_.IdToPiece(i)
      self.assertEqual(i, self.sp_.PieceToId(piece))

    self.assertEqual(1000, self.sp_.get_piece_size())
    self.assertEqual(0, self.sp_.piece_to_id('<unk>'))
    self.assertEqual(1, self.sp_.piece_to_id('<s>'))
    self.assertEqual(2, self.sp_.piece_to_id('</s>'))
    self.assertEqual('<unk>', self.sp_.id_to_piece(0))
    self.assertEqual('<s>', self.sp_.id_to_piece(1))
    self.assertEqual('</s>', self.sp_.id_to_piece(2))
    for i in range(self.sp_.get_piece_size()):
      piece = self.sp_.id_to_piece(i)
      self.assertEqual(i, self.sp_.piece_to_id(piece))

  def test_roundtrip(self):
    text = 'I saw a girl with a telescope.'
    ids = self.sp_.EncodeAsIds(text)
    pieces1 = self.sp_.EncodeAsPieces(text)
    pieces2 = self.sp_.NBestEncodeAsPieces(text, 10)[0]
    self.assertEqual(pieces1, pieces2)
    self.assertEqual(text, self.sp_.DecodePieces(pieces1))
    self.assertEqual(text, self.sp_.DecodeIds(ids))
    for n in range(100):
      self.assertEqual(
          text,
          self.sp_.DecodePieces(self.sp_.SampleEncodeAsPieces(text, 64, 0.5)),
      )
      self.assertEqual(
          text,
          self.sp_.DecodePieces(self.sp_.SampleEncodeAsPieces(text, -1, 0.5)),
      )
      self.assertEqual(
          text, self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, 64, 0.5))
      )
      self.assertEqual(
          text, self.sp_.DecodeIds(self.sp_.SampleEncodeAsIds(text, -1, 0.5))
      )

    ids2 = self.sp_.encode_as_ids(text)
    pieces3 = self.sp_.encode_as_pieces(text)
    pieces4 = self.sp_.nbest_encode_as_pieces(text, 10)[0]
    self.assertEqual(pieces3, pieces4)
    self.assertEqual(pieces1, pieces3)
    self.assertEqual(ids, ids2)
    self.assertEqual(text, self.sp_.decode_pieces(pieces3))
    self.assertEqual(text, self.sp_.decode_ids(ids2))
    for n in range(100):
      self.assertEqual(
          text,
          self.sp_.decode_pieces(
              self.sp_.sample_encode_as_pieces(text, 64, 0.5)
          ),
      )
      self.assertEqual(
          text,
          self.sp_.decode_pieces(
              self.sp_.sample_encode_as_pieces(text, -1, 0.5)
          ),
      )
      self.assertEqual(
          text,
          self.sp_.decode_ids(self.sp_.sample_encode_as_ids(text, 64, 0.5)),
      )
      self.assertEqual(
          text,
          self.sp_.decode_ids(self.sp_.sample_encode_as_ids(text, -1, 0.5)),
      )

    self.assertEqual(
        self.sp_.calculate_entropy(text, 0.1),
        self.sp_.CalculateEntropy(text, 0.1),
    )

  def test_ja_load(self):
    self.assertEqual(8000, self.jasp_.GetPieceSize())
    self.assertEqual(0, self.jasp_.PieceToId('<unk>'))
    self.assertEqual(1, self.jasp_.PieceToId('<s>'))
    self.assertEqual(2, self.jasp_.PieceToId('</s>'))
    self.assertEqual('<unk>', self.jasp_.IdToPiece(0))
    self.assertEqual('<s>', self.jasp_.IdToPiece(1))
    self.assertEqual('</s>', self.jasp_.IdToPiece(2))
    for i in range(self.jasp_.GetPieceSize()):
      piece = self.jasp_.IdToPiece(i)
      self.assertEqual(i, self.jasp_.PieceToId(piece))

    self.assertEqual(8000, self.jasp_.get_piece_size())
    self.assertEqual(0, self.jasp_.piece_to_id('<unk>'))
    self.assertEqual(1, self.jasp_.piece_to_id('<s>'))
    self.assertEqual(2, self.jasp_.piece_to_id('</s>'))
    self.assertEqual('<unk>', self.jasp_.id_to_piece(0))
    self.assertEqual('<s>', self.jasp_.id_to_piece(1))
    self.assertEqual('</s>', self.jasp_.id_to_piece(2))
    for i in range(self.jasp_.get_piece_size()):
      piece = self.jasp_.id_to_piece(i)
      self.assertEqual(i, self.jasp_.piece_to_id(piece))

  def test_ja_roundtrip(self):
    text = '清水寺は京都にある。'
    ids = self.jasp_.EncodeAsIds(text)
    pieces1 = self.jasp_.EncodeAsPieces(text)
    pieces2 = self.jasp_.NBestEncodeAsPieces(text, 10)[0]
    self.assertEqual(pieces1, pieces2)
    self.assertEqual(text, self.jasp_.DecodePieces(pieces1))
    self.assertEqual(text, self.jasp_.DecodeIds(ids))
    for n in range(100):
      self.assertEqual(
          text,
          self.jasp_.DecodePieces(
              self.jasp_.SampleEncodeAsPieces(text, 64, 0.5)
          ),
      )
      self.assertEqual(
          text,
          self.jasp_.DecodePieces(
              self.jasp_.SampleEncodeAsPieces(text, -1, 0.5)
          ),
      )

    ids2 = self.jasp_.encode_as_ids(text)
    pieces3 = self.jasp_.encode_as_pieces(text)
    pieces4 = self.jasp_.nbest_encode_as_pieces(text, 10)[0]
    self.assertEqual(pieces3, pieces4)
    self.assertEqual(pieces1, pieces3)
    self.assertEqual(ids, ids2)
    self.assertEqual(text, self.jasp_.decode_pieces(pieces1))
    self.assertEqual(text, self.jasp_.decode_ids(ids2))
    for n in range(100):
      self.assertEqual(
          text,
          self.jasp_.decode_pieces(
              self.jasp_.sample_encode_as_pieces(text, 64, 0.5)
          ),
      )
      self.assertEqual(
          text,
          self.jasp_.decode_pieces(
              self.jasp_.sample_encode_as_pieces(text, -1, 0.5)
          ),
      )

    self.assertEqual(
        self.jasp_.calculate_entropy(text, 0.1),
        self.jasp_.CalculateEntropy(text, 0.1),
    )

  def test_train(self):
    spm.SentencePieceTrainer.Train(
        '--input='
        + os.path.join(data_dir, 'botchan.txt')
        + ' --model_prefix=m --vocab_size=1000'
    )
    sp = spm.SentencePieceProcessor()
    sp.Load('m.model')
    with open(os.path.join(data_dir, 'botchan.txt'), 'r') as file:
      for line in file:
        sp.DecodePieces(sp.EncodeAsPieces(line))
        sp.DecodeIds(sp.EncodeAsIds(line))

  def test_train_iterator(self):
    spm.SentencePieceTrainer.Train(
        '--input='
        + os.path.join(data_dir, 'botchan.txt')
        + ' --model_prefix=m --vocab_size=1000'
    )
    # Load as 'rb' for Python3.5/2.7.
    os1 = io.BytesIO()
    os2 = io.BytesIO()

    # suppress logging (redirect to /dev/null)
    spm.SentencePieceTrainer.train(
        input=os.path.join(data_dir, 'botchan.txt'),
        model_prefix='m',
        vocab_size=1000,
        logstream=open(os.devnull, 'w'),
    )

    with open(os.path.join(data_dir, 'botchan.txt'), 'rb') as is1:
      spm.SentencePieceTrainer.train(
          sentence_iterator=is1,
          model_prefix='m',
          vocab_size=1000,
          logstream=open(os.devnull, 'w'),
      )

    spm.SentencePieceTrainer.train(
        input=os.path.join(data_dir, 'botchan.txt'),
        model_writer=os1,
        vocab_size=1000,
        logstream=open(os.devnull, 'w'),
    )

    with open(os.path.join(data_dir, 'botchan.txt'), 'rb') as is2:
      spm.SentencePieceTrainer.train(
          sentence_iterator=is2,
          model_writer=os2,
          vocab_size=1000,
          logstream=open(os.devnull, 'w'),
      )

    sp1 = spm.SentencePieceProcessor(model_proto=os1.getvalue())
    sp2 = spm.SentencePieceProcessor(model_proto=os2.getvalue())
    self.assertEqual(
        [sp1.id_to_piece(i) for i in range(sp1.get_piece_size())],
        [sp2.id_to_piece(i) for i in range(sp2.get_piece_size())],
    )

  def test_train_kwargs(self):
    # suppress logging (redirect to /dev/null)
    spm.SentencePieceTrainer.train(
        input=[os.path.join(data_dir, 'botchan.txt')],
        model_prefix='m',
        vocab_size=1002,
        user_defined_symbols=['foo', 'bar', ',', ' ', '\t', '\b', '\n', '\r'],
        logstream=open(os.devnull, 'w'),
    )
    sp = spm.SentencePieceProcessor()
    sp.Load('m.model')
    with open(os.path.join(data_dir, 'botchan.txt'), 'r') as file:
      for line in file:
        sp.DecodePieces(sp.EncodeAsPieces(line))
        sp.DecodeIds(sp.EncodeAsIds(line))

    s = 'hello\tworld\r\nthis\tis a \b pen'
    self.assertEqual(s, sp.decode(sp.encode(s)))

  def test_serialized_proto(self):
    text = 'I saw a girl with a telescope.'
    s1 = self.sp_.EncodeAsSerializedProto(text)
    s2 = self.sp_.SampleEncodeAsSerializedProto(text, 10, 0.2)
    s3 = self.sp_.NBestEncodeAsSerializedProto(text, 10)
    s4 = self.sp_.DecodePiecesAsSerializedProto(['foo', 'bar'])
    s5 = self.sp_.DecodeIdsAsSerializedProto([20, 30])

    t1 = self.sp_.encode_as_serialized_proto(text)
    t2 = self.sp_.sample_encode_as_serialized_proto(text, 10, 0.2)
    t3 = self.sp_.nbest_encode_as_serialized_proto(text, 10)
    t4 = self.sp_.decode_pieces_as_serialized_proto(['foo', 'bar'])
    t5 = self.sp_.decode_ids_as_serialized_proto([20, 30])

    y1 = self.sp_.encode(text, out_type='serialized_proto')
    y2 = self.sp_.encode(
        text, enable_sampling=True, out_type='serialized_proto'
    )
    y3 = self.sp_.nbest_encode(text, out_type='serialized_proto', nbest_size=10)
    y4 = self.sp_.decode(['foo', 'bar'], out_type='serialized_proto')
    y5 = self.sp_.decode([20, 30], out_type='serialized_proto')

    self.assertEqual(type(s1), bytes)
    self.assertEqual(type(s2), bytes)
    self.assertEqual(type(t2), bytes)
    self.assertEqual(type(s3), bytes)
    self.assertEqual(type(s4), bytes)
    self.assertEqual(type(s5), bytes)

    self.assertEqual(s1, t1)
    self.assertEqual(s3, t3)
    self.assertEqual(s4, t4)
    self.assertEqual(s5, t5)
    self.assertEqual(s1, y1)
    self.assertEqual(s3, y3)
    self.assertEqual(s4, y4)
    self.assertEqual(s5, y5)

    ids = self.jasp_.EncodeAsIds(text)
    pieces = self.jasp_.EncodeAsPieces(text)
    s1 = self.jasp_.EncodeAsSerializedProto(text)
    s2 = self.jasp_.DecodeIdsAsSerializedProto(ids)
    s3 = self.jasp_.DecodePiecesAsSerializedProto(pieces)
    self.assertEqual(s2, s1)
    self.assertEqual(s3, s1)

  def test_decode_bytes(self):
    texts = ['Hello world', '清水寺は京都にある。']
    ids = self.jasp_.encode(texts, out_type=int)
    s1 = self.jasp_.decode(ids, out_type=bytes)
    s2 = self.jasp_.decode(ids, out_type=str)
    self.assertEqual(len(s1), 2)
    self.assertEqual(type(s1[0]), bytes)
    self.assertEqual(type(s1[1]), bytes)
    self.assertEqual(len(s2), 2)
    self.assertEqual(type(s2[0]), str)
    self.assertEqual(type(s2[1]), str)
    self.assertEqual(s1[0].decode(encoding='utf-8'), s2[0])
    self.assertEqual(s1[1].decode(encoding='utf-8'), s2[1])

    text = 'Hello world'
    ids = self.jasp_.encode(text, out_type=int)
    s1 = self.jasp_.decode(ids, out_type=bytes)
    s2 = self.jasp_.decode(ids, out_type=str)
    self.assertEqual(type(s1), bytes)
    self.assertEqual(type(s2), str)
    self.assertEqual(s1.decode(encoding='utf-8'), s2)

    x = self.jasp_.encode(text, out_type='immutable_proto')
    self.assertEqual(x.text, x.text_as_bytes.decode(encoding='utf-8'))
    for sp in x.pieces:
      self.assertEqual(sp.piece, sp.piece_as_bytes.decode(encoding='utf-8'))
      self.assertEqual(sp.surface, sp.surface_as_bytes.decode(encoding='utf-8'))

    x = self.jasp_.decode(ids, out_type='immutable_proto')
    self.assertEqual(x.text, x.text_as_bytes.decode(encoding='utf-8'))
    for sp in x.pieces:
      self.assertEqual(sp.piece, sp.piece_as_bytes.decode(encoding='utf-8'))
      self.assertEqual(sp.surface, sp.surface_as_bytes.decode(encoding='utf-8'))

  def test_immutable_proto(self):
    text = 'I saw a girl with a telescope.'
    s1 = self.sp_.EncodeAsImmutableProto(text)
    s2 = self.sp_.SampleEncodeAsImmutableProto(text, 10, 0.2)
    s3 = self.sp_.NBestEncodeAsImmutableProto(text, 10)
    s4 = self.sp_.DecodePiecesAsImmutableProto(['foo', 'bar'])
    s5 = self.sp_.DecodeIdsAsImmutableProto([20, 30])

    print(s1)
    print(s2)
    print(s3)
    print(s4)
    print(s5)

    t1 = self.sp_.encode_as_immutable_proto(text)
    t2 = self.sp_.sample_encode_as_immutable_proto(text, 10, 0.2)
    t3 = self.sp_.nbest_encode_as_immutable_proto(text, 10)
    t4 = self.sp_.decode_pieces_as_immutable_proto(['foo', 'bar'])
    t5 = self.sp_.decode_ids_as_immutable_proto([20, 30])

    y1 = self.sp_.encode(text, out_type='immutable_proto')
    y2 = self.sp_.encode(text, enable_sampling=True, out_type='immutable_proto')
    y3 = self.sp_.nbest_encode(text, out_type='immutable_proto', nbest_size=10)
    y4 = self.sp_.decode(['foo', 'bar'], out_type='immutable_proto')
    y5 = self.sp_.decode([20, 30], out_type='immutable_proto')

    self.assertEqual(s1, t1)
    self.assertEqual(s3, t3)
    self.assertEqual(s4, t4)
    self.assertEqual(s5, t5)
    self.assertEqual(s1, y1)
    self.assertEqual(s3, y3)
    self.assertEqual(s4, y4)
    self.assertEqual(s5, y5)

    hset_piece = defaultdict(int)

    # eq test
    for i in range(len(s1.pieces)):
      self.assertEqual(s1.pieces[i], t1.pieces[i])
      hset_piece[s1.pieces[i]] += 1
      hset_piece[t1.pieces[i]] += 1

    self.assertEqual(len(hset_piece), len(s1.pieces))

    # hash test
    hset = defaultdict(int)
    hset[s1] += 1
    hset[t1] += 1
    hset[s3] += 1
    hset[t3] += 1

    self.assertEqual(len(hset), 2)
    self.assertEqual(hset[s1], 2)
    self.assertEqual(hset[s3], 2)
    self.assertEqual(hset[t1], 2)
    self.assertEqual(hset[t3], 2)

    x1 = self.sp_.encode_as_serialized_proto(text)
    x2 = self.sp_.sample_encode_as_serialized_proto(text, 10, 0.2)
    x3 = self.sp_.nbest_encode_as_serialized_proto(text, 10)
    x4 = self.sp_.decode_pieces_as_serialized_proto(['foo', 'bar'])
    x5 = self.sp_.decode_ids_as_serialized_proto([20, 30])

    self.assertEqual(x1, t1.SerializeAsString())
    self.assertEqual(x3, t3.SerializeAsString())
    self.assertEqual(x4, t4.SerializeAsString())
    self.assertEqual(x5, t5.SerializeAsString())

    v1 = self.sp_.EncodeAsIds(text)
    v2 = self.sp_.EncodeAsPieces(text)
    self.assertEqual([x.id for x in s1.pieces], v1)
    self.assertEqual([x.piece for x in s1.pieces], v2)
    self.assertEqual(text, s1.text)

    surfaces1 = [s1.text[x.begin : x.end] for x in s1.pieces]
    surfaces2 = [x.surface for x in s1.pieces]
    self.assertEqual(surfaces1, surfaces2)

    ids = []
    for i in range(len(s1.pieces)):
      ids.append(s1.pieces[i].id)
    self.assertEqual(ids, v1)

    pieces = []
    for i in range(len(s1.pieces)):
      pieces.append(s1.pieces[i].piece)
    self.assertEqual(pieces, v2)

    for v in s3.nbests:
      self.assertEqual(text, v.text)
      self.assertEqual(self.sp_.Decode([x.id for x in v.pieces]), text)

    for i in range(len(s3.nbests)):
      self.assertEqual(text, s3.nbests[i].text)
      self.assertEqual(
          self.sp_.Decode([x.id for x in s3.nbests[i].pieces]), text
      )

    # slice
    self.assertEqual(s1.pieces[::-1], list(reversed(s1.pieces)))
    self.assertEqual(s3.nbests[::-1], list(reversed(s3.nbests)))

    # Japanese offset
    s1 = self.jasp_.EncodeAsImmutableProto(
        '吾輩は猫である。Hello world. ABC 123'
    )
    surfaces1 = [s1.text[x.begin : x.end] for x in s1.pieces]
    surfaces2 = [x.surface for x in s1.pieces]
    self.assertEqual(surfaces1, surfaces2)

    ids = [x.id for x in s1.pieces]
    s2 = self.jasp_.DecodeIdsAsImmutableProto(ids)
    self.assertEqual(s2, s1)

    pieces = [x.piece for x in s1.pieces]
    s2 = self.jasp_.DecodePiecesAsImmutableProto(pieces)
    self.assertEqual(s2, s1)

  def test_new_api(self):
    sp = spm.SentencePieceProcessor(
        model_file=os.path.join('test', 'test_model.model')
    )
    text = 'hello world'
    text2 = 'Tokyo'
    ids = self.sp_.EncodeAsIds(text)
    ids2 = self.sp_.EncodeAsIds(text2)
    pieces = self.sp_.EncodeAsPieces(text)
    pieces2 = self.sp_.EncodeAsPieces(text2)
    sprotos = self.sp_.EncodeAsSerializedProto(text)
    sproto2 = self.sp_.EncodeAsSerializedProto(text2)
    iprotos = self.sp_.EncodeAsImmutableProto(text)
    iprotos2 = self.sp_.EncodeAsImmutableProto(text2)

    self.assertEqual(sp.encode(text, out_type=int), ids)
    self.assertEqual(sp.encode(text, out_type=str), pieces)
    self.assertEqual(sp.encode(text, out_type='serialized_proto'), sprotos)
    self.assertEqual(sp.encode(text, out_type='immutable_proto'), iprotos)

    self.assertEqual(sp.encode([text], out_type=int), [ids])
    self.assertEqual(sp.encode([text], out_type=str), [pieces])
    self.assertEqual(sp.encode([text], out_type='serialized_proto'), [sprotos])
    self.assertEqual(sp.encode([text], out_type='immutable_proto'), [iprotos])

    self.assertEqual(len(iprotos.pieces), len(pieces))
    self.assertEqual(len(iprotos.pieces), len(ids))
    self.assertEqual(iprotos.text, text)

    self.assertEqual(len(iprotos2.pieces), len(pieces2))
    self.assertEqual(len(iprotos2.pieces), len(ids2))
    self.assertEqual(iprotos2.text, text2)

    for i in range(len(iprotos.pieces)):
      self.assertEqual(ids[i], iprotos.pieces[i].id)
      self.assertEqual(pieces[i], iprotos.pieces[i].piece)

    for i, piece in enumerate(iprotos.pieces):
      self.assertEqual(ids[i], piece.id)
      self.assertEqual(pieces[i], piece.piece)

    for i in range(len(iprotos2.pieces)):
      self.assertEqual(ids2[i], iprotos2.pieces[i].id)
      self.assertEqual(pieces2[i], iprotos2.pieces[i].piece)

    for i, piece in enumerate(iprotos2.pieces):
      self.assertEqual(ids2[i], piece.id)
      self.assertEqual(pieces2[i], piece.piece)

    detok_ids = self.sp_.DecodeIds(ids)
    detok_pieces = self.sp_.DecodePieces(pieces)
    self.assertEqual(sp.decode(ids), detok_ids)
    self.assertEqual(sp.decode(pieces), detok_pieces)
    self.assertEqual(sp.decode([]), '')
    self.assertEqual(sp.decode([[]]), [''])

    # add_bos, add_eos, reverse
    self.assertEqual([sp.bos_id()] + ids, sp.encode(text, add_bos=True))
    self.assertEqual(ids + [sp.eos_id()], sp.encode(text, add_eos=True))
    self.assertEqual(ids + [sp.eos_id()], sp.EncodeAsIds(text, add_eos=True))
    rids = ids[:]
    rids.reverse()

    self.assertEqual(rids, sp.encode(text, reverse=True))
    self.assertEqual(rids, sp.EncodeAsIds(text, reverse=True))

    # different shape.
    self.assertEqual([ids, ids2], sp.encode([text, text2]))
    self.assertEqual([pieces, pieces2], sp.encode([text, text2], out_type=str))
    self.assertEqual([text, text2], sp.decode([ids, ids2]))
    self.assertEqual([text, text2], sp.decode([pieces, pieces2]))

    pieces = list(reversed(self.sp_.EncodeAsPieces(text)))
    self.assertEqual(pieces, sp.encode(text, reverse=True, out_type=str))

    # emit unk piece
    unk_char = '藤'
    pieces = self.sp_.EncodeAsIds(unk_char, emit_unk_piece=True)
    pieces2 = self.sp_.encode(unk_char, out_type=int, emit_unk_piece=True)
    self.assertEqual(pieces[1], sp.unk_id())
    self.assertEqual(pieces2[1], sp.unk_id())
    self.assertEqual(pieces, pieces2)

    pieces = self.sp_.EncodeAsPieces(unk_char, emit_unk_piece=True)
    pieces2 = self.sp_.encode(unk_char, out_type=str, emit_unk_piece=True)
    self.assertEqual(pieces[1], '<unk>')
    self.assertEqual(pieces2[1], '<unk>')
    self.assertEqual(pieces, pieces2)

    pieces = self.sp_.EncodeAsPieces(unk_char, emit_unk_piece=False)
    pieces2 = self.sp_.encode(unk_char, out_type=str, emit_unk_piece=False)
    self.assertEqual(pieces[1], unk_char)
    self.assertEqual(pieces2[1], unk_char)
    self.assertEqual(pieces, pieces2)

  def test_new_api_init(self):
    sp = spm.SentencePieceProcessor(
        model_file=os.path.join('test', 'test_model.model'),
        add_bos=True,
        add_eos=True,
        out_type=str,
    )
    text = 'hello world'
    pieces = ['<s>'] + self.sp_.EncodeAsPieces(text) + ['</s>']
    self.assertEqual(pieces, sp.encode(text))

    pieces = self.sp_.EncodeAsPieces(text) + ['</s>']
    self.assertEqual(pieces, sp.encode(text, add_bos=False, add_eos=True))

  def test_sampling(self):
    sp = self.sp_

    for out_type in [str, int, 'serialized_proto', 'immutable_proto']:
      ids = defaultdict(int)
      for n in range(100):
        out = sp.encode('hello world', out_type=out_type, enable_sampling=True)
        if type(out) is list:
          out = tuple(out)
        ids[out] += 1
      self.assertGreater(len(ids), 1)

      ids2 = defaultdict(int)
      for n in range(100):
        out = sp.encode('hello world', out_type=out_type, enable_sampling=False)
        if type(out) is list:
          out = tuple(out)
        ids2[out] += 1
      self.assertEqual(len(ids2), 1)

      out = sp.encode(
          ['hello world', 'this is a test'],
          out_type=out_type,
          enable_sampling=True,
      )
      self.assertEqual(len(out), 2)
      out = sp.encode(
          ['hello world', 'this is a test'],
          out_type=out_type,
          enable_sampling=False,
      )
      self.assertEqual(len(out), 2)

  def test_nbest(self):
    sp = self.sp_
    text = 'hello world'
    text2 = 'I have a pen.'

    for out_type in [str, int, 'serialized_proto', 'immutable_proto']:
      results = sp.nbest_encode(text, nbest_size=10, out_type=out_type)
      self.assertEqual(
          results, sp.NBestEncode(text, nbest_size=10, out_type=out_type)
      )

      if out_type in [str, int]:
        for n in results:
          self.assertEqual(sp.decode(n), text)

        for n in sp.decode(results):
          self.assertEqual(n, text)

      # batch test
      results = sp.nbest_encode([text, text2], nbest_size=10, out_type=out_type)
      self.assertEqual(
          results,
          sp.NBestEncode([text, text2], nbest_size=10, out_type=out_type),
      )
      self.assertEqual(len(results), 2)

      if out_type in [str, int]:
        for n in results[0]:
          self.assertEqual(sp.decode(n), text)

        for n in results[1]:
          self.assertEqual(sp.decode(n), text2)

        decoded = sp.decode(results[0])
        self.assertEqual(len(decoded), 10)
        for n in decoded:
          self.assertEqual(n, text)
        decoded = sp.decode(results[1])
        self.assertEqual(len(decoded), 10)
        for n in decoded:
          self.assertEqual(n, text2)

    self.assertEqual(
        sp.nbest_encode(text, nbest_size=10, out_type=str),
        sp.nbest_encode_as_pieces(text, nbest_size=10),
    )
    self.assertEqual(
        sp.nbest_encode(text, nbest_size=10, out_type=int),
        sp.nbest_encode_as_ids(text, nbest_size=10),
    )
    self.assertEqual(
        sp.nbest_encode(text, nbest_size=10, out_type='serialized_proto'),
        sp.nbest_encode_as_serialized_proto(text, nbest_size=10),
    )
    self.assertEqual(
        sp.nbest_encode(text, nbest_size=10, out_type='immutable_proto'),
        sp.nbest_encode_as_immutable_proto(text, nbest_size=10),
    )

  def test_sample_and_score(self):
    sp = self.sp_
    text = 'hello world'
    text2 = 'I have a pen.'
    for out_type in [str, int, 'serialized_proto', 'immutable_proto']:
      results = sp.sample_encode_and_score(
          text, wor=True, num_samples=10, out_type=out_type
      )
      results = sp.SampleEncodeAndScore(
          text, wor=False, num_samples=10, out_type=out_type
      )

      if out_type in [str, int]:
        for n in results:
          self.assertEqual(sp.decode(n[0]), text)

      results = sp.sample_encode_and_score(
          [text, text2], wor=True, num_samples=10, out_type=out_type
      )
      results = sp.SampleEncodeAndScore(
          [text, text2], wor=True, num_samples=10, out_type=out_type
      )

      if out_type in [str, int]:
        for n in results[0]:
          self.assertEqual(sp.decode(n[0]), text)
        for n in results[1]:
          self.assertEqual(sp.decode(n[0]), text2)

    sp.sample_encode_and_score_as_pieces(text, 10)
    sp.sample_encode_and_score_as_ids(text, 10)
    sp.sample_encode_and_score_as_immutable_proto(text, 10)
    sp.sample_encode_and_score_as_serialized_proto(text, 10)

  def test_valid_range(self):
    size = self.sp_.piece_size()
    funcs = [
        'IdToPiece',
        'GetScore',
        'IsUnknown',
        'IsControl',
        'IsUnused',
        'IsByte',
        'DecodeIds',
        'DecodeIdsAsSerializedProto',
    ]
    for m in funcs:
      getattr(self.sp_, m)([10, 20, 30])

    for m in funcs:
      try:
        getattr(self.sp_, m)([size])
        self.assertTrue(False)
      except:
        self.assertTrue(True)

  def test_batch(self):
    sp = spm.SentencePieceProcessor(
        model_file=os.path.join('test', 'test_model.model')
    )
    with open(os.path.join(data_dir, 'botchan.txt'), 'r') as file:
      texts = file.readlines()

    for out_type in [str, int, 'serialized_proto', 'immutable_proto']:
      r1 = sp.encode(texts, out_type=out_type, num_threads=None)
      r2 = sp.encode(texts, out_type=out_type, num_threads=1)
      r3 = sp.encode(texts, out_type=out_type, num_threads=-1)
      r4 = sp.encode(texts, out_type=out_type, num_threads=8)
      r5 = [sp.encode(s, out_type=out_type) for s in texts]
      self.assertEqual(r1, r2)
      self.assertEqual(r1, r3)
      self.assertEqual(r1, r4)
      self.assertEqual(r1, r5)

      if out_type in [str, int]:
        d1 = sp.decode(r1, num_threads=None)
        d2 = sp.decode(r2, num_threads=1)
        d3 = sp.decode(r3, num_threads=-1)
        d4 = sp.decode(r4, num_threads=8)
        d5 = [sp.decode(s) for s in r5]

        self.assertEqual(d1, d2)
        self.assertEqual(d1, d3)
        self.assertEqual(d1, d4)
        self.assertEqual(d1, d5)

    e1 = sp.calculate_entropy(texts, alpha=1.0, num_threads=10)
    e2 = sp.CalculateEntropy(texts, alpha=1.0, num_threads=10)
    e3 = [sp.calculate_entropy(s, alpha=1.0) for s in texts]
    self.assertEqual(e1, e2)
    self.assertEqual(e1, e3)

  def test_pickle(self):
    with open('sp.pickle', 'wb') as f:
      pickle.dump(self.sp_, f)

    id1 = self.sp_.encode('hello world.', out_type=int)

    with open('sp.pickle', 'rb') as f:
      sp = pickle.load(f)

    id2 = sp.encode('hello world.', out_type=int)

    self.assertEqual(id1, id2)

  def test_global_params(self):
    spm.SetRandomGeneratorSeed(0)
    spm.SetMinLogLevel(2)
    spm.set_random_generator_seed(1)
    spm.set_min_log_level(3)

  def test_normalize(self):
    sp = spm.SentencePieceProcessor(
        model_file=os.path.join('test', 'test_model.model')
    )

    self.assertEqual('▁KADOKAWAABC', sp.normalize('ＫＡＤＯＫＡＷＡABC'))
    self.assertEqual('▁KADOKAWAABC', sp.Normalize('ＫＡＤＯＫＡＷＡABC'))

    x = sp.Normalize('ＫＡＤＯＫＡＷＡABC', with_offsets=True)
    self.assertEqual('▁KADOKAWAABC', x[0])
    self.assertEqual([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], x[1])

    x = sp.Normalize('ＫＡＤＯＫＡＷＡABC'.encode('utf8'), with_offsets=True)
    self.assertEqual('▁KADOKAWAABC'.encode('utf8'), x[0])
    self.assertEqual(
        [0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[1]
    )

    self.assertEqual(
        ['▁KADOKAWAABC', '▁平成'], sp.normalize(['ＫＡＤＯＫＡＷＡABC', '㍻'])
    )
    self.assertEqual(
        ['▁KADOKAWAABC', '▁平成'], sp.Normalize(['ＫＡＤＯＫＡＷＡABC', '㍻'])
    )

    x = sp.Normalize(
        ['ＫＡＤＯＫＡＷＡABC'.encode('utf8'), '㍻'.encode('utf8')],
        with_offsets=True,
    )
    self.assertEqual(len(x), 2)
    self.assertEqual('▁KADOKAWAABC'.encode('utf8'), x[0][0])
    self.assertEqual(
        [0, 0, 0, 0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[0][1]
    )

    x = sp.Normalize(['ＫＡＤＯＫＡＷＡABC', '㍻'], with_offsets=True)
    self.assertEqual(len(x), 2)
    self.assertEqual('▁KADOKAWAABC', x[0][0])
    self.assertEqual([0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], x[0][1])

    self.assertEqual('▁平成', x[1][0])
    self.assertEqual([0, 0, 0, 1], x[1][1])

  def test_normalizer(self):
    sp = spm.SentencePieceNormalizer(
        model_file=os.path.join('test', 'test_model.model')
    )

    self.assertEqual('KADOKAWAABC', sp.normalize('ＫＡＤＯＫＡＷＡABC'))
    self.assertEqual('KADOKAWAABC', sp.Normalize('ＫＡＤＯＫＡＷＡABC'))

    x = sp.Normalize('ＫＡＤＯＫＡＷＡABC'.encode('utf8'), with_offsets=True)
    self.assertEqual('KADOKAWAABC'.encode('utf8'), x[0])
    self.assertEqual([0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[1])

    x = sp.Normalize('ＫＡＤＯＫＡＷＡABC', with_offsets=True)
    self.assertEqual('KADOKAWAABC', x[0])
    self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], x[1])

    self.assertEqual(
        ['KADOKAWAABC', '平成'], sp.normalize(['ＫＡＤＯＫＡＷＡABC', '㍻'])
    )
    self.assertEqual(
        ['KADOKAWAABC', '平成'], sp.Normalize(['ＫＡＤＯＫＡＷＡABC', '㍻'])
    )

    x = sp.Normalize(
        ['ＫＡＤＯＫＡＷＡABC'.encode('utf8'), '㍻'.encode('utf8')],
        with_offsets=True,
    )
    self.assertEqual(len(x), 2)
    self.assertEqual('KADOKAWAABC'.encode('utf8'), x[0][0])
    self.assertEqual([0, 3, 6, 9, 12, 15, 18, 21, 24, 25, 26, 27], x[0][1])

    x = sp.Normalize(['ＫＡＤＯＫＡＷＡABC', '㍻'], with_offsets=True)
    self.assertEqual(len(x), 2)
    self.assertEqual('KADOKAWAABC', x[0][0])
    self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], x[0][1])
    self.assertEqual('平成', x[1][0])
    self.assertEqual([0, 0, 1], x[1][1])

    sp = spm.SentencePieceNormalizer(
        model_file=os.path.join('test', 'test_model.model'),
        add_dummy_prefix=True,
        escape_whitespaces=True,
        remove_extra_whitespaces=False,
    )
    self.assertEqual('▁hello▁▁world', sp.normalize('hello  world'))

    sp = spm.SentencePieceNormalizer(
        model_file=os.path.join('test', 'test_model.model'),
        add_dummy_prefix=True,
        escape_whitespaces=True,
        remove_extra_whitespaces=True,
    )
    self.assertEqual('▁hello▁world', sp.normalize(' hello world '))

    sp = spm.SentencePieceNormalizer(
        model_file=os.path.join('test', 'test_model.model'),
        add_dummy_prefix=False,
        escape_whitespaces=False,
        remove_extra_whitespaces=True,
    )
    self.assertEqual('hello world', sp.normalize(' hello world '))

  def test_normalizer_rule(self):
    sp = spm.SentencePieceNormalizer(rule_name='identity')
    self.assertEqual('ＡＢＣ', sp.Normalize('ＡＢＣ'))

    sp = spm.SentencePieceNormalizer(rule_name='nfkc_cf')
    self.assertEqual('abc', sp.Normalize('ＡＢＣ'))

  def test_override_normalize_spec(self):
    sp = spm.SentencePieceProcessor(
        model_file=os.path.join('test', 'test_model.model')
    )

    self.assertEqual(
        sp.EncodeAsPieces(' hello world '), ['▁he', 'll', 'o', '▁world']
    )

    sp.override_normalizer_spec(add_dummy_prefix=False)
    sp.override_normalizer_spec(remove_extra_whitespaces=False)
    sp.override_normalizer_spec(escape_whitespaces=False)
    self.assertEqual(
        sp.EncodeAsPieces(' hello world '),
        [' ', 'he', 'll', 'o', ' ', 'w', 'or', 'l', 'd', ' '],
    )


def suite():
  suite = unittest.TestSuite()
  suite.addTests(unittest.makeSuite(TestSentencepieceProcessor))
  return suite


if __name__ == '__main__':
  unittest.main()

Binary file not shown.
Binary file not shown.
@@ -1,11 +0,0 @@
prefix=@prefix@
exec_prefix=@exec_prefix@
libdir=@libdir_for_pc_file@
includedir=@includedir_for_pc_file@

Name: @PROJECT_NAME@
Description: Unsupervised text tokenizer and detokenizer for Neural Network-based text generation.
Version: @PROJECT_VERSION@
Libs: -L${libdir} -lsentencepiece -lsentencepiece_train
Cflags: -I${includedir}
Requires.private: @libprotobuf_lite@
@ -0,0 +1,235 @@
|
|||
//
|
||||
// Created by Local on 02/12/2025.
|
||||
//
|
||||
|
||||
#include <jni.h>
|
||||
#include <string>
|
||||
#include <android/log.h>
|
||||
#include "libs/bergamot/src/translator/byte_array_util.h"
|
||||
#include "libs/bergamot/src/translator/parser.h"
|
||||
#include "libs/bergamot/src/translator/response.h"
|
||||
#include "libs/bergamot/src/translator/response_options.h"
|
||||
#include "libs/bergamot/src/translator/service.h"
|
||||
#include "libs/bergamot/src/translator/utils.h"
|
||||
#include <string>
|
||||
using namespace marian::bergamot;
|
||||
#include <unordered_map>
|
||||
#include <mutex>
|
||||
|
||||
struct ModelContainer{
|
||||
public:
|
||||
std::shared_ptr<TranslationModel> toEnglishModel;
|
||||
std::shared_ptr<TranslationModel> fromEnglishModel;
|
||||
|
||||
ModelContainer(std::shared_ptr<TranslationModel> toEnglishModel, std::shared_ptr<TranslationModel> fromEnglishModel){
|
||||
this->toEnglishModel = toEnglishModel;
|
||||
this->fromEnglishModel = fromEnglishModel;
|
||||
}
|
||||
};
|
||||
|
||||
static std::unordered_map<std::string, std::shared_ptr<ModelContainer>> model_cache;
|
||||
static std::unique_ptr<BlockingService> global_service = nullptr;
|
||||
static std::mutex service_mutex;
|
||||
static std::mutex translation_mutex;
|
||||
|
||||
void initializeService() {
|
||||
std::lock_guard<std::mutex> lock(service_mutex);
|
||||
|
||||
if (global_service == nullptr) {
|
||||
BlockingService::Config blockingConfig;
|
||||
blockingConfig.cacheSize = 0; //todo: change it back to 256 in release
|
||||
blockingConfig.logger.level = "off";
|
||||
global_service = std::make_unique<BlockingService>(blockingConfig);
|
||||
}
|
||||
}
|
||||
|
||||
void loadModelIntoCache(const std::string& toEngCfg, const std::string& fromEngConfig, const std::string& lang) {
|
||||
std::lock_guard<std::mutex> lock(service_mutex);
|
||||
|
||||
auto validate = true;
|
||||
auto pathsDir = "";
|
||||
|
||||
if (model_cache.find(lang) == model_cache.end()) {
|
||||
auto toEngOptions = parseOptionsFromString(toEngCfg, validate, pathsDir);
|
||||
auto fromEngOptions = parseOptionsFromString(fromEngConfig, validate, pathsDir);
|
||||
model_cache[lang] = std::make_shared<ModelContainer>(std::make_shared<TranslationModel>(toEngOptions),
|
||||
std::make_shared<TranslationModel>(fromEngOptions));
|
||||
}
|
||||
}
|
||||
|
||||
void unloadModelFromCache(const std::string& lang) {
|
||||
std::lock_guard<std::mutex> lock(service_mutex);
|
||||
|
||||
model_cache.erase(lang);
|
||||
}
|
||||
|
||||
std::vector<std::string> translateMultiple(std::vector<std::string> &&inputs, const std::string& srcLang, const std::string& trgLang) {
|
||||
initializeService();
|
||||
|
||||
std::shared_ptr<TranslationModel> firstModel = nullptr;
|
||||
std::shared_ptr<TranslationModel> secondModel = nullptr;
|
||||
|
||||
if(srcLang == trgLang) return inputs;
|
||||
|
||||
// Assume models are already loaded in cache
|
||||
if (srcLang != "en") {
|
||||
if(model_cache.find(srcLang) == model_cache.end()) throw std::runtime_error("Missing loaded src model");
|
||||
firstModel = model_cache[srcLang]->toEnglishModel;
|
||||
}
|
||||
if (trgLang != "en") {
|
||||
if(model_cache.find(trgLang) == model_cache.end()) throw std::runtime_error("Missing loaded trg model");
|
||||
secondModel = model_cache[trgLang]->fromEnglishModel;
|
||||
}
|
||||
|
||||
std::vector<ResponseOptions> responseOptions;
|
||||
responseOptions.reserve(inputs.size());
|
||||
for (size_t i = 0; i < inputs.size(); ++i) {
|
||||
ResponseOptions opts;
|
||||
opts.HTML = false;
|
||||
opts.qualityScores = false;
|
||||
opts.alignment = false;
|
||||
opts.sentenceMappings = false;
|
||||
responseOptions.emplace_back(opts);
|
||||
}
|
||||
|
||||
std::lock_guard<std::mutex> translation_lock(translation_mutex);
|
||||
std::vector<Response> responses;
|
||||
if (firstModel != nullptr && secondModel != nullptr) {
|
||||
responses = global_service->pivotMultiple(firstModel, secondModel, std::move(inputs),
|
||||
responseOptions);
|
||||
} else if (firstModel != nullptr) {
|
||||
responses = global_service->translateMultiple(firstModel, std::move(inputs),
|
||||
responseOptions);
|
||||
} else if (secondModel != nullptr) {
|
||||
responses = global_service->translateMultiple(secondModel, std::move(inputs),
|
||||
responseOptions);
|
||||
} else {
|
||||
throw std::runtime_error("Missing loaded models");
|
||||
}
|
||||
|
    std::vector<std::string> results;
    results.reserve(responses.size());
    for (const auto &response: responses) {
        results.push_back(response.target.text);
    }

    return results;
}

void cleanup(){
    std::lock_guard<std::mutex> lock(service_mutex);
    global_service.reset();
    model_cache.clear();
}


extern "C" __attribute__((visibility("default"))) JNIEXPORT void JNICALL
Java_nie_translator_rtranslator_voice_1translation_neural_1networks_translation_BergamotTranslator_initializeServiceNative(
        JNIEnv* env,
        jclass /* this */) {
    try {
        initializeService();
    } catch(const std::exception &e) {
        jclass exceptionClass = env->FindClass("java/lang/RuntimeException");
        env->ThrowNew(exceptionClass, e.what());
    }
}

extern "C" __attribute__((visibility("default"))) JNIEXPORT void JNICALL
Java_nie_translator_rtranslator_voice_1translation_neural_1networks_translation_BergamotTranslator_loadModelIntoCacheNative(
        JNIEnv* env,
        jclass /* this */,
        jstring toEngCfg,
        jstring fromEngCfg,
        jstring lang) {

    const char* c_toEngCfg = env->GetStringUTFChars(toEngCfg, nullptr);
    const char* c_fromEngCfg = env->GetStringUTFChars(fromEngCfg, nullptr);
    const char* c_lang = env->GetStringUTFChars(lang, nullptr);

    try {
        std::string toEngCfg_str(c_toEngCfg);
        std::string fromEngCfg_str(c_fromEngCfg);
        std::string key_str(c_lang);
        loadModelIntoCache(toEngCfg_str, fromEngCfg_str, key_str);
    } catch(const std::exception &e) {
        jclass exceptionClass = env->FindClass("java/lang/RuntimeException");
        env->ThrowNew(exceptionClass, e.what());
    }

    env->ReleaseStringUTFChars(toEngCfg, c_toEngCfg);
    env->ReleaseStringUTFChars(fromEngCfg, c_fromEngCfg);
    env->ReleaseStringUTFChars(lang, c_lang);
}

extern "C" __attribute__((visibility("default"))) JNIEXPORT void JNICALL
Java_nie_translator_rtranslator_voice_1translation_neural_1networks_translation_BergamotTranslator_unloadModelFromCacheNative(
        JNIEnv* env,
        jclass /* this */,
        jstring lang) {

    const char* c_lang = env->GetStringUTFChars(lang, nullptr);

    try {
        std::string lang_str(c_lang);
        unloadModelFromCache(lang_str);
    } catch(const std::exception &e) {
        jclass exceptionClass = env->FindClass("java/lang/RuntimeException");
        env->ThrowNew(exceptionClass, e.what());
    }

    env->ReleaseStringUTFChars(lang, c_lang);
}

extern "C" __attribute__((visibility("default"))) JNIEXPORT jobjectArray JNICALL
Java_nie_translator_rtranslator_voice_1translation_neural_1networks_translation_BergamotTranslator_translateMultipleNative(
        JNIEnv *env,
        jclass /* this */,
        jobjectArray inputs,
        jstring srcLang,
        jstring trgLang) {

    const char *c_srcLang = env->GetStringUTFChars(srcLang, nullptr);
    const char *c_trgLang = env->GetStringUTFChars(trgLang, nullptr);

    jsize inputCount = env->GetArrayLength(inputs);
    std::vector<std::string> cpp_inputs;
    cpp_inputs.reserve(inputCount);

    for (jsize i = 0; i < inputCount; i++) {
        auto jstr = (jstring) env->GetObjectArrayElement(inputs, i);
        const char *c_str = env->GetStringUTFChars(jstr, nullptr);
        cpp_inputs.emplace_back(c_str);
        env->ReleaseStringUTFChars(jstr, c_str);
        env->DeleteLocalRef(jstr);
    }

    jobjectArray result = nullptr;
    try {
        std::string srcLang_str(c_srcLang);
        std::string trgLang_str(c_trgLang);
        std::vector<std::string> translations = translateMultiple(std::move(cpp_inputs), srcLang_str, trgLang_str);

        jclass stringClass = env->FindClass("java/lang/String");
        result = env->NewObjectArray((jsize) translations.size(), stringClass, nullptr);

        for (size_t i = 0; i < translations.size(); ++i) {
            jstring jstr = env->NewStringUTF(translations[i].c_str());
            env->SetObjectArrayElement(result, (jsize) i, jstr);
            env->DeleteLocalRef(jstr);
        }
    } catch (const std::exception &e) {
        jclass exceptionClass = env->FindClass("java/lang/RuntimeException");
        env->ThrowNew(exceptionClass, e.what());
    }

    env->ReleaseStringUTFChars(srcLang, c_srcLang);
    env->ReleaseStringUTFChars(trgLang, c_trgLang);

    return result;
}

extern "C" __attribute__((visibility("default"))) JNIEXPORT void JNICALL
Java_nie_translator_rtranslator_voice_1translation_neural_1networks_translation_BergamotTranslator_cleanupNative(JNIEnv* env, jclass /* this */) {
    cleanup();
}
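Side note: the JNI symbol names above fully determine the Java-side declarations (a jclass receiver means static methods; jstring maps to String, and a jobjectArray of java/lang/String maps to String[]). A minimal sketch of the presumed Java counterpart follows; the loadLibrary name is an assumption taken from the bergamot_translator_interface CMake target further down, not something this diff confirms:

package nie.translator.rtranslator.voice_translation.neural_networks.translation;

public class BergamotTranslator {
    static {
        // Assumed: the .so name matches the CMake target "bergamot_translator_interface".
        System.loadLibrary("bergamot_translator_interface");
    }

    // Signatures inferred from the JNI bindings above.
    public static native void initializeServiceNative();
    public static native void loadModelIntoCacheNative(String toEngCfg, String fromEngCfg, String lang);
    public static native void unloadModelFromCacheNative(String lang);
    public static native String[] translateMultipleNative(String[] inputs, String srcLang, String trgLang);
    public static native void cleanupNative();
}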
@ -0,0 +1,241 @@
//
// Created by Local on 27/11/2025.
//

#include <jni.h>
#include <string>
#include <android/log.h>
#include "libs/bergamot/src/translator/byte_array_util.h"
#include "libs/bergamot/src/translator/parser.h"
#include "libs/bergamot/src/translator/response.h"
#include "libs/bergamot/src/translator/response_options.h"
#include "libs/bergamot/src/translator/service.h"
#include "libs/bergamot/src/translator/utils.h"
using namespace marian::bergamot;

#include <unordered_map>
#include <mutex>
static std::unordered_map<std::string, std::shared_ptr<TranslationModel>> model_cache;
static std::unique_ptr<BlockingService> global_service = nullptr;
static std::mutex service_mutex;
static std::mutex translation_mutex;

void initializeService() {
    std::lock_guard<std::mutex> lock(service_mutex);

    if (global_service == nullptr) {
        BlockingService::Config blockingConfig;
        blockingConfig.cacheSize = 256;
        blockingConfig.logger.level = "off";
        global_service = std::make_unique<BlockingService>(blockingConfig);
    }
}

void loadModelIntoCache(const std::string& cfg, const std::string& key) {
    std::lock_guard<std::mutex> lock(service_mutex);

    auto validate = true;
    auto pathsDir = "";

    if (model_cache.find(key) == model_cache.end()) {
        auto options = parseOptionsFromString(cfg, validate, pathsDir);
        model_cache[key] = std::make_shared<TranslationModel>(options);
    }
}

std::vector<std::string> translateMultiple(std::vector<std::string> &&inputs, const char *key) {
    initializeService();

    std::string key_str(key);

    // Assume model is already loaded in cache
    std::shared_ptr<TranslationModel> model = model_cache[key_str];

    std::vector<ResponseOptions> responseOptions;
    responseOptions.reserve(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        ResponseOptions opts;
        opts.HTML = false;
        opts.qualityScores = false;
        opts.alignment = false;
        opts.sentenceMappings = false;
        responseOptions.emplace_back(opts);
    }

    std::lock_guard<std::mutex> translation_lock(translation_mutex);
    std::vector<Response> responses = global_service->translateMultiple(model, std::move(inputs), responseOptions);

    std::vector<std::string> results;
    results.reserve(responses.size());
    for (const auto &response: responses) {
        results.push_back(response.target.text);
    }

    return results;
}

std::vector<std::string> pivotMultiple(const char *firstKey, const char *secondKey, std::vector<std::string> &&inputs) {
    initializeService();

    std::string first_key_str(firstKey);
    std::string second_key_str(secondKey);

    // Assume models are already loaded in cache
    std::shared_ptr<TranslationModel> firstModel = model_cache[first_key_str];
    std::shared_ptr<TranslationModel> secondModel = model_cache[second_key_str];

    std::vector<ResponseOptions> responseOptions;
    responseOptions.reserve(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        ResponseOptions opts;
        opts.HTML = false;
        opts.qualityScores = false;
        opts.alignment = false;
        opts.sentenceMappings = false;
        responseOptions.emplace_back(opts);
    }

    std::lock_guard<std::mutex> translation_lock(translation_mutex);
    std::vector<Response> responses = global_service->pivotMultiple(firstModel, secondModel, std::move(inputs), responseOptions);

    std::vector<std::string> results;
    results.reserve(responses.size());
    for (const auto &response: responses) {
        results.push_back(response.target.text);
    }

    return results;
}

extern "C" __attribute__((visibility("default"))) JNIEXPORT void JNICALL
Java_dev_davidv_bergamot_NativeLib_initializeService(
        JNIEnv* env,
        jobject /* this */) {
    try {
        initializeService();
    } catch(const std::exception &e) {
        jclass exceptionClass = env->FindClass("java/lang/RuntimeException");
        env->ThrowNew(exceptionClass, e.what());
    }
}

extern "C" __attribute__((visibility("default"))) JNIEXPORT void JNICALL
Java_dev_davidv_bergamot_NativeLib_loadModelIntoCache(
        JNIEnv* env,
        jobject /* this */,
        jstring cfg,
        jstring key) {

    const char* c_cfg = env->GetStringUTFChars(cfg, nullptr);
    const char* c_key = env->GetStringUTFChars(key, nullptr);

    try {
        std::string cfg_str(c_cfg);
        std::string key_str(c_key);
        loadModelIntoCache(cfg_str, key_str);
    } catch(const std::exception &e) {
        jclass exceptionClass = env->FindClass("java/lang/RuntimeException");
        env->ThrowNew(exceptionClass, e.what());
    }

    env->ReleaseStringUTFChars(cfg, c_cfg);
    env->ReleaseStringUTFChars(key, c_key);
}

extern "C" __attribute__((visibility("default"))) JNIEXPORT jobjectArray JNICALL
Java_dev_davidv_bergamot_NativeLib_translateMultiple(
        JNIEnv *env,
        jobject /* this */,
        jobjectArray inputs,
        jstring key) {

    const char *c_key = env->GetStringUTFChars(key, nullptr);

    jsize inputCount = env->GetArrayLength(inputs);
    std::vector<std::string> cpp_inputs;
    cpp_inputs.reserve(inputCount);

    for (jsize i = 0; i < inputCount; i++) {
        auto jstr = (jstring) env->GetObjectArrayElement(inputs, i);
        const char *c_str = env->GetStringUTFChars(jstr, nullptr);
        cpp_inputs.emplace_back(c_str);
        env->ReleaseStringUTFChars(jstr, c_str);
        env->DeleteLocalRef(jstr);
    }

    jobjectArray result = nullptr;
    try {
        std::vector<std::string> translations = translateMultiple(std::move(cpp_inputs), c_key);

        jclass stringClass = env->FindClass("java/lang/String");
        result = env->NewObjectArray((jsize) translations.size(), stringClass, nullptr);

        for (size_t i = 0; i < translations.size(); ++i) {
            jstring jstr = env->NewStringUTF(translations[i].c_str());
            env->SetObjectArrayElement(result, (jsize) i, jstr);
            env->DeleteLocalRef(jstr);
        }
    } catch (const std::exception &e) {
        jclass exceptionClass = env->FindClass("java/lang/RuntimeException");
        env->ThrowNew(exceptionClass, e.what());
    }

    env->ReleaseStringUTFChars(key, c_key);

    return result;
}

extern "C" __attribute__((visibility("default"))) JNIEXPORT jobjectArray JNICALL
Java_dev_davidv_bergamot_NativeLib_pivotMultiple(
        JNIEnv *env,
        jobject /* this */,
        jstring firstKey,
        jstring secondKey,
        jobjectArray inputs) {

    const char *c_firstKey = env->GetStringUTFChars(firstKey, nullptr);
    const char *c_secondKey = env->GetStringUTFChars(secondKey, nullptr);

    jsize inputCount = env->GetArrayLength(inputs);
    std::vector<std::string> cpp_inputs;
    cpp_inputs.reserve(inputCount);

    for (jsize i = 0; i < inputCount; i++) {
        auto jstr = (jstring) env->GetObjectArrayElement(inputs, i);
        const char *c_str = env->GetStringUTFChars(jstr, nullptr);
        cpp_inputs.emplace_back(c_str);
        env->ReleaseStringUTFChars(jstr, c_str);
        env->DeleteLocalRef(jstr);
    }

    jobjectArray result = nullptr;
    try {
        std::vector<std::string> translations = pivotMultiple(c_firstKey, c_secondKey, std::move(cpp_inputs));

        jclass stringClass = env->FindClass("java/lang/String");
        result = env->NewObjectArray((jsize) translations.size(), stringClass, nullptr);

        for (size_t i = 0; i < translations.size(); ++i) {
            jstring jstr = env->NewStringUTF(translations[i].c_str());
            env->SetObjectArrayElement(result, (jsize) i, jstr);
            env->DeleteLocalRef(jstr);
        }
    } catch (const std::exception &e) {
        jclass exceptionClass = env->FindClass("java/lang/RuntimeException");
        env->ThrowNew(exceptionClass, e.what());
    }

    env->ReleaseStringUTFChars(firstKey, c_firstKey);
    env->ReleaseStringUTFChars(secondKey, c_secondKey);

    return result;
}

// Cleanup function to be called when the library is unloaded
extern "C" __attribute__((visibility("default"))) JNIEXPORT void JNICALL
Java_dev_davidv_bergamot_NativeLib_cleanup(JNIEnv* env, jobject /* this */) {
    std::lock_guard<std::mutex> lock(service_mutex);
    global_service.reset();
    model_cache.clear();
}
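For the same reason, the Java_dev_davidv_bergamot_NativeLib_* symbols with a jobject receiver imply instance methods on a dev.davidv.bergamot.NativeLib class. A hedged usage sketch of pivot translation through two cached models; the library name, config strings, and cache keys are placeholders, not taken from this diff:

package dev.davidv.bergamot;

public class NativeLib {
    static {
        System.loadLibrary("bergamot");  // assumed .so name, not shown in this diff
    }

    // Declarations inferred from the JNI bindings above.
    public native void initializeService();
    public native void loadModelIntoCache(String cfg, String key);
    public native String[] translateMultiple(String[] inputs, String key);
    public native String[] pivotMultiple(String firstKey, String secondKey, String[] inputs);
    public native void cleanup();

    public static void main(String[] args) {
        NativeLib lib = new NativeLib();
        lib.initializeService();
        String itEnYaml = "...";  // hypothetical Bergamot YAML model config
        String enDeYaml = "...";
        lib.loadModelIntoCache(itEnYaml, "it-en");
        lib.loadModelIntoCache(enDeYaml, "en-de");
        // Italian -> German by pivoting through English, mirroring pivotMultiple above.
        String[] out = lib.pivotMultiple("it-en", "en-de", new String[]{"Ciao mondo"});
        System.out.println(out[0]);
        lib.cleanup();
    }
}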
@ -1,4 +1,4 @@
# Copyright 2018 Google Inc.
# Copyright 2016 Luca Martino.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,369 +12,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.

if (SPM_ABSL_PROVIDER STREQUAL "module" OR SPM_ABSL_PROVIDER STREQUAL "package")
  set(ABSL_FLAGS_SRCS "")
  set(ABSL_STRINGS_SRCS "")
  list(APPEND SPM_LIBS absl::strings)
  list(APPEND SPM_LIBS absl::flags)
  list(APPEND SPM_LIBS absl::flags_parse)
  list(APPEND SPM_LIBS absl::log)
  list(APPEND SPM_LIBS absl::check)
  if (MSVC)
    add_definitions("/D_USE_EXTERNAL_ABSL")
  else()
    add_definitions("-D_USE_EXTERNAL_ABSL")
  endif()
elseif (SPM_ABSL_PROVIDER STREQUAL "internal")
  set(ABSL_FLAGS_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/absl/flags/flag.cc)
endif()
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(rtranslator_native)

if (SPM_PROTOBUF_PROVIDER STREQUAL "internal")
  set(SPM_PROTO_HDRS builtin_pb/sentencepiece.pb.h)
  set(SPM_PROTO_SRCS builtin_pb/sentencepiece.pb.cc)
  set(SPM_MODEL_PROTO_HDRS builtin_pb/sentencepiece_model.pb.h)
  set(SPM_MODEL_PROTO_SRCS builtin_pb/sentencepiece_model.pb.cc)
  set(PROTOBUF_LITE_LIBRARY "")
  set(PROTOBUF_LITE_SRCS
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/arena.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/arenastring.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/bytestream.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/coded_stream.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/common.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/extension_set.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/generated_enum_util.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/generated_message_table_driven_lite.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/generated_message_util.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/implicit_weak_message.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/int128.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/io_win32.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/message_lite.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/parse_context.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/repeated_field.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/status.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/statusor.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/stringpiece.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/stringprintf.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/structurally_valid.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/strutil.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/time.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/wire_format_lite.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/zero_copy_stream.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/zero_copy_stream_impl.cc
    ${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite/zero_copy_stream_impl_lite.cc)
  if (MSVC)
    add_definitions("/DHAVE_PTHREAD /wd4018 /wd4514")
  else()
    add_definitions("-pthread -DHAVE_PTHREAD=1 -Wno-sign-compare -Wno-deprecated-declarations")
  endif()
  include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../third_party/protobuf-lite)
  include_directories(builtin_pb)
elseif (SPM_PROTOBUF_PROVIDER STREQUAL "package")
  find_package(Protobuf REQUIRED)
  include_directories(${Protobuf_INCLUDE_DIRS})
  protobuf_generate_cpp(SPM_PROTO_SRCS SPM_PROTO_HDRS sentencepiece.proto)
  protobuf_generate_cpp(SPM_MODEL_PROTO_SRCS SPM_MODEL_PROTO_HDRS sentencepiece_model.proto)
  set(PROTOBUF_LITE_SRCS "")
  include_directories(${PROTOBUF_INCLUDE_DIR})
  if (MSVC)
    add_definitions("/D_USE_EXTERNAL_PROTOBUF")
  else()
    add_definitions("-D_USE_EXTERNAL_PROTOBUF")
  endif()
endif()
add_subdirectory(libs/sentencepiece)
add_subdirectory(libs/bergamot)

include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../third_party)

if (MSVC)
  add_definitions("/D_USE_INTERNAL_STRING_VIEW")
else()
  add_definitions("-D_USE_INTERNAL_STRING_VIEW")
endif()

set(SPM_SRCS
  ${PROTOBUF_LITE_SRCS}
  ${SPM_PROTO_HDRS}
  ${SPM_PROTO_SRCS}
  ${SPM_MODEL_PROTO_HDRS}
  ${SPM_MODEL_PROTO_SRCS}
  bpe_model.h
  common.h
  normalizer.h
  util.h
  freelist.h
  filesystem.h
  init.h
  sentencepiece_processor.h
  word_model.h
  model_factory.h
  char_model.h
  model_interface.h
  testharness.h
  unigram_model.h
  bpe_model.cc
  char_model.cc
  error.cc
  filesystem.cc
  model_factory.cc
  model_interface.cc
  normalizer.cc
  sentencepiece_processor.cc
  unigram_model.cc
  util.cc
  word_model.cc
  SentencePieceProcessorInterface.cpp
  NNAPITest.cpp
  CacheContainerNative.cpp
  ${ABSL_STRINGS_SRCS}
  ${ABSL_FLAGS_SRCS})

set(SPM_TRAIN_SRCS
  ${SPM_PROTO_HDRS}
  ${SPM_MODEL_PROTO_HDRS}
  builder.h
  normalization_rule.h
  unicode_script.h
  unicode_script_map.h
  trainer_factory.h
  trainer_interface.h
  unigram_model_trainer.h
  word_model_trainer.h
  char_model_trainer.h
  bpe_model_trainer.h
  sentencepiece_trainer.h
  pretokenizer_for_training.h
  builder.cc
  unicode_script.cc
  trainer_factory.cc
  trainer_interface.cc
  unigram_model_trainer.cc
  word_model_trainer.cc
  char_model_trainer.cc
  bpe_model_trainer.cc
  sentencepiece_trainer.cc
  pretokenizer_for_training.cc)

set(SPM_TEST_SRCS
  ${SPM_PROTO_HDRS}
  ${SPM_MODEL_PROTO_HDRS}
  testharness.h
  bpe_model_test.cc
  bpe_model_trainer_test.cc
  builder_test.cc
  char_model_test.cc
  char_model_trainer_test.cc
  filesystem_test.cc
  init_test.cc
  model_factory_test.cc
  model_interface_test.cc
  normalizer_test.cc
  sentencepiece_processor_test.cc
  sentencepiece_trainer_test.cc
  test_main.cc
  testharness.cc
  trainer_factory_test.cc
  trainer_interface_test.cc
  unicode_script_test.cc
  unigram_model_test.cc
  unigram_model_trainer_test.cc
  util_test.cc
  word_model_test.cc
  word_model_trainer_test.cc
  pretokenizer_for_training_test.cc)

find_package(Threads REQUIRED)

list(APPEND SPM_LIBS ${PROTOBUF_LITE_LIBRARY} Threads::Threads)

if (SPM_ENABLE_NFKC_COMPILE)
  find_package(ICU 4.4 COMPONENTS i18n data uc REQUIRED)
  include_directories(${ICU_INCLUDE_DIRS})
  add_definitions(-DENABLE_NFKC_COMPILE)
  list(APPEND SPM_LIBS ICU::i18n ICU::data ICU::uc)
endif()

if (SPM_ENABLE_TCMALLOC)
  if (SPM_TCMALLOC_STATIC)
    find_library(TCMALLOC_LIB NAMES libtcmalloc_minimal.a)
  else()
    find_library(TCMALLOC_LIB NAMES tcmalloc_minimal)
  endif()
  if (TCMALLOC_LIB)
    message(STATUS "Found TCMalloc: ${TCMALLOC_LIB}")
    list(APPEND SPM_LIBS ${TCMALLOC_LIB})
    add_definitions(-fno-builtin-malloc -fno-builtin-calloc -fno-builtin-realloc -fno-builtin-free)
  else()
    message(STATUS "Not Found TCMalloc: ${TCMALLOC_LIB}")
  endif()
endif()

if ((${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm") OR
    (${CMAKE_SYSTEM_PROCESSOR} MATCHES "mips") OR
    (${CMAKE_SYSTEM_PROCESSOR} MATCHES "m68k") OR
    (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc") OR
    (${CMAKE_SYSTEM_PROCESSOR} MATCHES "powerpc") OR
    (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch") OR
    (${CMAKE_SYSTEM_PROCESSOR} MATCHES "sh4"))
  find_library(ATOMIC_LIB NAMES atomic libatomic.so libatomic.so.1)
  if (ATOMIC_LIB)
    message(STATUS "Found atomic: ${ATOMIC_LIB}")
    list(APPEND SPM_LIBS "atomic")
  endif()
endif()
add_library(sentencepiece-interface SHARED SentencePieceProcessorInterface.cpp)
add_library(cache_container_native SHARED CacheContainerNative.cpp)
add_library(bergamot_translator_interface SHARED BergamotTranslator.cpp)
add_library(bergamot_translator_interface_old SHARED BergamotTranslator_old.cpp)


if (SPM_ENABLE_SHARED)
  add_library(sentencepiece SHARED ${SPM_SRCS})
  add_library(sentencepiece_train SHARED ${SPM_TRAIN_SRCS})
  if (ANDROID)
    target_link_libraries(sentencepiece log)
    target_link_libraries(sentencepiece_train log)
  endif()
endif()
target_link_libraries(sentencepiece-interface
  sentencepiece-static-android
  android
  log
)

add_library(sentencepiece-static STATIC ${SPM_SRCS})
add_library(sentencepiece_train-static STATIC ${SPM_TRAIN_SRCS})
target_link_libraries(cache_container_native
  android
  log
)

target_link_libraries(sentencepiece-static INTERFACE ${SPM_LIBS})
target_link_libraries(sentencepiece_train-static INTERFACE sentencepiece-static ${SPM_LIBS})
target_link_libraries(bergamot_translator_interface
  bergamot-translator
  android
  log
)

if (SPM_ENABLE_SHARED)
  target_link_libraries(sentencepiece ${SPM_LIBS})
  target_link_libraries(sentencepiece_train ${SPM_LIBS} sentencepiece)
  set(SPM_INSTALLTARGETS sentencepiece sentencepiece_train sentencepiece-static sentencepiece_train-static)
  set_target_properties(sentencepiece sentencepiece_train PROPERTIES SOVERSION 0 VERSION 0.0.0)
  set_target_properties(sentencepiece PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS YES)
  set_target_properties(sentencepiece_train PROPERTIES WINDOWS_EXPORT_ALL_SYMBOLS YES)
  if (MSVC)
    set_target_properties(sentencepiece PROPERTIES IMPORT_SUFFIX "_import.lib")
    set_target_properties(sentencepiece_train PROPERTIES IMPORT_SUFFIX "_import.lib")
  elseif (MINGW)
    set_target_properties(sentencepiece PROPERTIES IMPORT_SUFFIX ".dll.a")
    set_target_properties(sentencepiece_train PROPERTIES IMPORT_SUFFIX ".dll.a")
  endif()
else()
  add_library(sentencepiece ALIAS sentencepiece-static)
  add_library(sentencepiece_train ALIAS sentencepiece_train-static)
  set(SPM_INSTALLTARGETS sentencepiece-static sentencepiece_train-static)
endif()

set_target_properties(sentencepiece-static PROPERTIES OUTPUT_NAME "sentencepiece")
set_target_properties(sentencepiece_train-static PROPERTIES OUTPUT_NAME "sentencepiece_train")

if (NOT MSVC)
  if (SPM_COVERAGE)
    set(CMAKE_CXX_FLAGS "-O0 -Wall -fPIC -coverage ${CMAKE_CXX_FLAGS}")
  else()
    set(CMAKE_CXX_FLAGS "-O3 -Wall -fPIC ${CMAKE_CXX_FLAGS}")
  endif()
  if (SPM_ENABLE_TENSORFLOW_SHARED)
    add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0)
  endif()
  if (SPM_NO_THREADLOCAL)
    add_definitions(-DSPM_NO_THREADLOCAL=1)
    add_definitions(-DGOOGLE_PROTOBUF_NO_THREADLOCAL=1)
  endif()
  set_source_files_properties(
    sentencepiece.pb.cc sentencepiece_model.pb.cc
    PROPERTIES COMPILE_FLAGS "-Wno-misleading-indentation")
  set_source_files_properties(${SPM_TEST_SRCS}
    PROPERTIES COMPILE_FLAGS "-Wno-sign-compare")
  if (SPM_ENABLE_SHARED)
    set_property(TARGET sentencepiece APPEND_STRING PROPERTY COMPILE_FLAGS " -DPIC")
    set_property(TARGET sentencepiece_train APPEND_STRING PROPERTY COMPILE_FLAGS " -DPIC")
  endif()
endif()

add_executable(spm_encode spm_encode_main.cc)
add_executable(spm_decode spm_decode_main.cc)
add_executable(spm_normalize spm_normalize_main.cc)
add_executable(spm_train spm_train_main.cc)
add_executable(spm_export_vocab spm_export_vocab_main.cc)

target_link_libraries(spm_encode sentencepiece)
target_link_libraries(spm_decode sentencepiece)
target_link_libraries(spm_normalize sentencepiece sentencepiece_train)
target_link_libraries(spm_train sentencepiece sentencepiece_train)
target_link_libraries(spm_export_vocab sentencepiece)

if (SPM_ENABLE_NFKC_COMPILE)
  add_executable(compile_charsmap compile_charsmap_main.cc)
  target_link_libraries(compile_charsmap sentencepiece sentencepiece_train)
endif()

list(APPEND SPM_INSTALLTARGETS
  spm_encode spm_decode spm_normalize spm_train spm_export_vocab)

if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
  install(TARGETS ${SPM_INSTALLTARGETS}
    BUNDLE DESTINATION ${CMAKE_INSTALL_BINDIR}
    RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
    LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
    ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
else()
  install(TARGETS ${SPM_INSTALLTARGETS}
    RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
    LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
    ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR})
endif()

install(FILES sentencepiece_trainer.h sentencepiece_processor.h
  DESTINATION ${CMAKE_INSTALL_INCDIR})
if (NOT SPM_PROTOBUF_PROVIDER STREQUAL "internal")
  install(FILES ${SPM_PROTO_HDRS} DESTINATION ${CMAKE_INSTALL_INCDIR})
endif()

file(TO_NATIVE_PATH "${PROJECT_SOURCE_DIR}/data" data_dir)

if (SPM_BUILD_TEST OR SPM_COVERAGE)
  enable_testing()
  add_executable(spm_test test_main.cc ${SPM_TEST_SRCS})

  if (SPM_COVERAGE)
    target_link_libraries(spm_test sentencepiece sentencepiece_train "-lgcov")
  else()
    target_link_libraries(spm_test sentencepiece sentencepiece_train)
  endif()

  set(MEMORYCHECK_COMMAND_OPTIONS "--leak-check=full --show-leak-kinds=definite,possible --error-exitcode=1")
  find_program(CTEST_MEMORYCHECK_COMMAND NAMES valgrind)
  include(Dart)

  add_test(NAME sentencepiece_test
    COMMAND $<TARGET_FILE:spm_test> --test_srcdir=${data_dir})
endif()

#add_library(${CMAKE_PROJECT_NAME} SHARED
# List C/C++ source files with relative paths to this CMakeLists.txt.
#SentencePieceProcessorInterface.cpp)

target_link_libraries(${CMAKE_PROJECT_NAME}
  # List libraries to link to the target library
  android
  log)

if (SPM_COVERAGE)
  add_custom_target(coverage
    COMMAND mkdir -p coverage
    COMMAND $<TARGET_FILE:spm_test> --test_srcdir=${data_dir}
    COMMAND lcov -c -d . -o coverage.info
    COMMAND lcov --remove coverage.info "include*" "/c++" "_test*" "testharness*" "third_party*" ".pb.*" -o coverage.info
    COMMAND mkdir -p lcov_html
    COMMAND genhtml -o lcov_html coverage.info)
  add_dependencies(coverage spm_test)
endif()

if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
  set_xcode_property(spm_encode PRODUCT_BUNDLE_IDENTIFIER "SentencePiece" All)
  set_xcode_property(spm_decode PRODUCT_BUNDLE_IDENTIFIER "SentencePiece" All)
  set_xcode_property(spm_normalize PRODUCT_BUNDLE_IDENTIFIER "SentencePiece" All)
  set_xcode_property(spm_train PRODUCT_BUNDLE_IDENTIFIER "SentencePiece" All)
  set_xcode_property(spm_export_vocab PRODUCT_BUNDLE_IDENTIFIER "SentencePiece" All)
endif()

target_compile_options(sentencepiece PRIVATE
  "$<$<CONFIG:DEBUG>:-O0>"
)

#add_executable(sentenceprocessorinterface SentencePieceProcessorInterface.cpp)
target_link_libraries(bergamot_translator_interface_old
  bergamot-translator
  android
  log
)
@ -1,37 +0,0 @@
//
// Created by luca on 03/03/24.
//
#include <jni.h>
#include <stdio.h>
#include <string>
#include <android/asset_manager_jni.h>
#include <android/log.h>
#include <android/sharedmem.h>
#include <sys/mman.h>
#include <unistd.h>
#include <android/NeuralNetworks.h>
#include <android/NeuralNetworksTypes.h>

jstring stringToJstring2(JNIEnv* env, std::string str);

extern "C" jstring
Java_com_bluetooth_communicatorexample_nnapi_NNAPITest_getAvailableDevices(JNIEnv* env, jclass clazz){
    /*uint32_t n_device = 0;
    std::string ret;
    ANeuralNetworks_getDeviceCount(&n_device);
    for(int i=0; i < n_device; i++){
        ANeuralNetworksDevice * device;
        ANeuralNetworks_getDevice(i,&device);
        const char *name = nullptr;
        ANeuralNetworksDevice_getName(device, &name);
        int32_t version = 0;
        ANeuralNetworksDevice_getType(device, &version);
    }*/
    return stringToJstring2(env,"");
}


jstring stringToJstring2(JNIEnv* env, std::string str){
    const char* chars = str.data();
    return env->NewStringUTF(chars);
}
@ -17,7 +17,7 @@
#include <jni.h>
#include <stdio.h>
#include <string>
#include "sentencepiece_processor.h"
#include "libs/sentencepiece/src/sentencepiece_processor.h"

using namespace sentencepiece;
@ -44,7 +44,7 @@ Java_nie_translator_rtranslator_voice_1translation_neural_1networks_translation_
    std::vector<int> ids(1024,0);
    std::string string = jstringToString(env,text);
    (*proc).Encode(string, &ids);
    return intVectorTojintArray(env,ids);;
    return intVectorTojintArray(env,ids);
}

extern "C" jint
@ -1,205 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "bpe_model.h"

#include <functional>
#include <memory>
#include <queue>
#include <random>
#include <utility>
#include <vector>

#include "freelist.h"
#include "third_party/absl/container/flat_hash_map.h"
#include "util.h"

namespace sentencepiece {
namespace bpe {

Model::Model(const ModelProto &model_proto) {
  model_proto_ = &model_proto;
  InitializePieces();
}

Model::~Model() {}

std::vector<std::pair<absl::string_view, int>> Model::SampleEncode(
    absl::string_view normalized, float alpha) const {
  if (!status().ok() || normalized.empty()) {
    return {};
  }

  struct SymbolPair {
    int left;     // left index of this pair
    int right;    // right index of this pair
    float score;  // score of this pair. large is better.
    size_t size;  // length of this piece
  };

  class SymbolPairComparator {
   public:
    const bool operator()(SymbolPair *h1, SymbolPair *h2) {
      return (h1->score < h2->score ||
              (h1->score == h2->score && h1->left > h2->left));
    }
  };

  struct Symbol {
    int prev;     // prev index of this symbol. -1 for BOS.
    int next;     // next index of this symbol. -1 for EOS.
    bool freeze;  // this symbol is never merged.
    absl::string_view piece;
  };

  using Agenda = std::priority_queue<SymbolPair *, std::vector<SymbolPair *>,
                                     SymbolPairComparator>;
  Agenda agenda;
  std::vector<Symbol> symbols;
  symbols.reserve(normalized.size());

  // Reverse merge rules.
  // key: merged symbol, value: pair of original symbols.
  absl::flat_hash_map<absl::string_view,
                      std::pair<absl::string_view, absl::string_view>>
      rev_merge;

  // Pre-allocates SymbolPair for efficiency.
  constexpr size_t kPreallocateSymbolPairSize = 256;
  model::FreeList<SymbolPair> symbol_pair_allocator(kPreallocateSymbolPairSize);

  // Looks up the new symbol pair at [left, right] and inserts it into the agenda.
  auto MaybeAddNewSymbolPair = [this, &symbol_pair_allocator, &symbols, &agenda,
                                &rev_merge](int left, int right) {
    if (left == -1 || right == -1 || symbols[left].freeze ||
        symbols[right].freeze)
      return;
    const absl::string_view piece(
        symbols[left].piece.data(),
        symbols[left].piece.size() + symbols[right].piece.size());
    const auto it = pieces_.find(piece);
    if (it == pieces_.end()) {
      return;
    }
    auto *h = symbol_pair_allocator.Allocate();
    h->left = left;
    h->right = right;
    h->score = GetScore(it->second);
    h->size = piece.size();
    agenda.push(h);

    // Makes `rev_merge` for resegmentation.
    if (IsUnusedInlined(it->second)) {
      rev_merge[piece] =
          std::make_pair(symbols[left].piece, symbols[right].piece);
    }
  };

  // Splits the input into character sequence
  int index = 0;
  while (!normalized.empty()) {
    Symbol s;
    const int mblen = matcher_->PrefixMatch(normalized, &s.freeze);
    s.piece = absl::string_view(normalized.data(), mblen);
    s.prev = index == 0 ? -1 : index - 1;
    normalized.remove_prefix(mblen);
    s.next = normalized.empty() ? -1 : index + 1;
    ++index;
    symbols.emplace_back(s);
  }

  if (symbols.empty()) {
    return {};
  }

  // Lookup all bigrams.
  for (size_t i = 1; i < symbols.size(); ++i) {
    MaybeAddNewSymbolPair(i - 1, i);
  }

  // BPE-dropout: https://arxiv.org/pdf/1910.13267.pdf
  std::mt19937 *rand_gen = nullptr;
  auto skip_merge = [&]() {
    if (alpha <= 0.0) return false;
    if (alpha >= 1.0) return true;
    if (rand_gen == nullptr) rand_gen = random::GetRandomGenerator();
    std::uniform_real_distribution<> gen(0.0, 1.0);
    return gen(*rand_gen) < alpha;
  };

  // Main loop.
  while (!agenda.empty()) {
    SymbolPair *top = agenda.top();
    agenda.pop();

    // `top` is no longer available.
    if (symbols[top->left].piece.empty() || symbols[top->right].piece.empty() ||
        symbols[top->left].piece.size() + symbols[top->right].piece.size() !=
            top->size) {
      continue;
    }

    // Note that the original BPE-dropout paper assumes that all merged symbols are
    // precomputed, but here we randomly skip the merge operation inside this loop.
    // This implementation is theoretically equivalent to the original one.
    if (skip_merge()) continue;

    // Replaces symbols with `top` rule.
    symbols[top->left].piece = absl::string_view(
        symbols[top->left].piece.data(),
        symbols[top->left].piece.size() + symbols[top->right].piece.size());

    // Updates prev/next pointers.
    symbols[top->left].next = symbols[top->right].next;
    if (symbols[top->right].next >= 0) {
      symbols[symbols[top->right].next].prev = top->left;
    }
    symbols[top->right].piece = absl::string_view("");

    // Adds new symbol pairs which are newly added after symbol replacement.
    MaybeAddNewSymbolPair(symbols[top->left].prev, top->left);
    MaybeAddNewSymbolPair(top->left, symbols[top->left].next);
  }

  std::function<void(absl::string_view, EncodeResult *)> resegment;
  resegment = [this, &resegment, &rev_merge](absl::string_view w,
                                             EncodeResult *output) -> void {
    const int id = PieceToId(w);
    if (id == -1 || !IsUnusedInlined(id)) {
      output->emplace_back(w, id);
      return;
    }
    const auto p = rev_merge.find(w);
    if (p == rev_merge.end()) {
      // This block will never be called, as `rev_merge` stores all the
      // resegmentation info for unused id.
      output->emplace_back(w, id);
      return;
    }
    // Recursively resegment left and right symbols.
    resegment(p->second.first, output);
    resegment(p->second.second, output);
  };

  EncodeResult output;
  for (int index = 0; index != -1; index = symbols[index].next) {
    if (index >= 0 && index < static_cast<int>(symbols.size())) {
      resegment(symbols[index].piece, &output);
    }
  }

  return output;
}
}  // namespace bpe
}  // namespace sentencepiece
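The removed bpe_model.cc above implements greedy BPE segmentation: adjacent symbol pairs are merged best-score-first via a priority queue (the "agenda"), leftmost pair first on ties. Purely for illustration, here is a minimal Java sketch of that same greedy rule using a simple quadratic scan instead of the agenda; the vocabulary and scores are hypothetical, chosen to reproduce the "abcd" -> "ab", "cd" expectation from the tests below:

import java.util.*;

public class BpeSketch {
    public static List<String> encode(String text, Map<String, Double> pieceScores) {
        // Start from single characters, mirroring the "split into character
        // sequence" step of the C++ implementation.
        List<String> symbols = new ArrayList<>();
        for (int i = 0; i < text.length(); ) {
            int cp = text.codePointAt(i);
            symbols.add(new String(Character.toChars(cp)));
            i += Character.charCount(cp);
        }
        while (symbols.size() > 1) {
            int bestIndex = -1;
            double bestScore = Double.NEGATIVE_INFINITY;
            // Scan all adjacent pairs; ties keep the leftmost pair, like the
            // C++ comparator's (score, left-index) ordering.
            for (int i = 0; i + 1 < symbols.size(); i++) {
                String merged = symbols.get(i) + symbols.get(i + 1);
                Double score = pieceScores.get(merged);
                if (score != null && score > bestScore) {
                    bestScore = score;
                    bestIndex = i;
                }
            }
            if (bestIndex < 0) break;  // no mergeable pair left in the vocabulary
            symbols.set(bestIndex, symbols.get(bestIndex) + symbols.get(bestIndex + 1));
            symbols.remove(bestIndex + 1);
        }
        return symbols;
    }

    public static void main(String[] args) {
        // Hypothetical pieces and scores (larger is better, as in the C++ code).
        Map<String, Double> scores = Map.of("ab", 0.0, "cd", -0.1, "abc", -0.2);
        System.out.println(encode("abcd", scores));  // prints [ab, cd]
    }
}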
@ -1,52 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef BPE_MODEL_H_
#define BPE_MODEL_H_

#include "model_interface.h"
#include "sentencepiece_model.pb.h"

namespace sentencepiece {
namespace bpe {

// Segmentation model with BPE (Byte Pair Encoding)
// Details:
// Neural Machine Translation of Rare Words with Subword Units
// https://arxiv.org/abs/1508.07909
//
// https://en.wikipedia.org/wiki/Byte_pair_encoding
class Model : public ModelInterface {
 public:
  explicit Model(const ModelProto &model_proto);
  ~Model() override;

  EncodeResult Encode(absl::string_view normalized) const override {
    return SampleEncode(normalized, 0.0);
  }

  // Sampling with BPE-dropout: https://arxiv.org/pdf/1910.13267.pdf
  // `alpha` is dropout probability in BPE-dropout paper.
  // Skips merge operation with `alpha` probability.
  // When alpha <= 0.0, no sampling is performed.
  EncodeResult SampleEncode(absl::string_view normalized,
                            float alpha) const override;

  bool IsSampleEncodeAvailable() const override { return true; }

  bool IsNBestEncodeAvailable() const override { return false; }
};
}  // namespace bpe
}  // namespace sentencepiece
#endif  // BPE_MODEL_H_
@ -1,299 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstdio>
#include <string>

#include "bpe_model.h"
#include "model_interface.h"
#include "testharness.h"

namespace sentencepiece {
namespace bpe {
namespace {

ModelProto MakeBaseModelProto() {
  ModelProto model_proto;
  auto *sp1 = model_proto.add_pieces();
  auto *sp2 = model_proto.add_pieces();
  auto *sp3 = model_proto.add_pieces();

  sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
  sp1->set_piece("<unk>");
  sp2->set_type(ModelProto::SentencePiece::CONTROL);
  sp2->set_piece("<s>");
  sp3->set_type(ModelProto::SentencePiece::CONTROL);
  sp3->set_piece("</s>");

  return model_proto;
}

void AddPiece(ModelProto *model_proto, const std::string &piece,
              float score = 0.0) {
  auto *sp = model_proto->add_pieces();
  sp->set_piece(piece);
  sp->set_score(score);
}

TEST(BPEModelTest, EncodeTest) {
  ModelProto model_proto = MakeBaseModelProto();

  AddPiece(&model_proto, "ab", 0.0);         // 3
  AddPiece(&model_proto, "cd", -0.1);        // 4
  AddPiece(&model_proto, "abc", -0.2);       // 5
  AddPiece(&model_proto, "a", -0.3);         // 6
  AddPiece(&model_proto, "b", -0.4);         // 7
  AddPiece(&model_proto, "c", -0.5);         // 8
  AddPiece(&model_proto, "ABC", -0.5);       // 9
  AddPiece(&model_proto, "abcdabcd", -0.5);  // 10
  AddPiece(&model_proto, "q", -0.5);         // 11
  AddPiece(&model_proto, "r", -0.5);         // 12
  AddPiece(&model_proto, "qr", -0.5);        // 13
  model_proto.mutable_pieces(9)->set_type(   // ABC
      ModelProto::SentencePiece::USER_DEFINED);
  model_proto.mutable_pieces(10)->set_type(  // abcdabcd
      ModelProto::SentencePiece::USER_DEFINED);
  model_proto.mutable_pieces(11)->set_type(  // q
      ModelProto::SentencePiece::USER_DEFINED);
  model_proto.mutable_pieces(12)->set_type(  // r
      ModelProto::SentencePiece::USER_DEFINED);

  const Model model(model_proto);

  EncodeResult result;

  result = model.Encode("");
  EXPECT_TRUE(result.empty());

  result = model.Encode("abc");
  EXPECT_EQ(1, result.size());
  EXPECT_EQ("abc", result[0].first);

  result = model.Encode("AB");
  EXPECT_EQ(2, result.size());
  EXPECT_EQ("A", result[0].first);
  EXPECT_EQ("B", result[1].first);

  result = model.Encode("abcd");
  EXPECT_EQ(2, result.size());
  EXPECT_EQ("ab", result[0].first);
  EXPECT_EQ("cd", result[1].first);

  result = model.Encode("abcc");
  EXPECT_EQ(2, result.size());
  EXPECT_EQ("abc", result[0].first);
  EXPECT_EQ("c", result[1].first);

  result = model.Encode("xabcabaabcdd");
  EXPECT_EQ(7, result.size());
  EXPECT_EQ("x", result[0].first);
  EXPECT_EQ("abc", result[1].first);
  EXPECT_EQ("ab", result[2].first);
  EXPECT_EQ("a", result[3].first);
  EXPECT_EQ("ab", result[4].first);
  EXPECT_EQ("cd", result[5].first);
  EXPECT_EQ("d", result[6].first);

  // all unknown.
  result = model.Encode("xyz東京");
  EXPECT_EQ(5, result.size());
  EXPECT_EQ("x", result[0].first);
  EXPECT_EQ("y", result[1].first);
  EXPECT_EQ("z", result[2].first);
  EXPECT_EQ("東", result[3].first);
  EXPECT_EQ("京", result[4].first);

  // User defined
  result = model.Encode("ABC");
  EXPECT_EQ(1, result.size());
  EXPECT_EQ("ABC", result[0].first);

  result = model.Encode("abABCcd");
  EXPECT_EQ(3, result.size());
  EXPECT_EQ("ab", result[0].first);
  EXPECT_EQ("ABC", result[1].first);
  EXPECT_EQ("cd", result[2].first);

  // middle "abcdabcd" is user defined.
  result = model.Encode("ababcdabcdcd");
  EXPECT_EQ(3, result.size());
  EXPECT_EQ("ab", result[0].first);
  EXPECT_EQ("abcdabcd", result[1].first);
  EXPECT_EQ("cd", result[2].first);

  result = model.Encode("abqrcd");
  EXPECT_EQ(4, result.size());
  EXPECT_EQ("ab", result[0].first);
  EXPECT_EQ("q", result[1].first);
  EXPECT_EQ("r", result[2].first);
  EXPECT_EQ("cd", result[3].first);
}

TEST(BPEModelTest, EncodeAmbiguousTest) {
  ModelProto model_proto = MakeBaseModelProto();

  AddPiece(&model_proto, "aa", -0.1);
  AddPiece(&model_proto, "bb", -0.2);
  AddPiece(&model_proto, "ab", -0.3);
  AddPiece(&model_proto, "a", -0.4);
  AddPiece(&model_proto, "b", -0.5);

  const Model model(model_proto);

  EncodeResult result;

  // leftmost symbols are merged first.
  result = model.Encode("aaa");
  EXPECT_EQ(2, result.size());
  EXPECT_EQ("aa", result[0].first);
  EXPECT_EQ("a", result[1].first);

  // "bb" is replaced earlier than "ab".
  result = model.Encode("aabb");
  EXPECT_EQ(2, result.size());
  EXPECT_EQ("aa", result[0].first);
  EXPECT_EQ("bb", result[1].first);

  // "bb" is replaced earlier than "ab".
  result = model.Encode("aaabbb");
  EXPECT_EQ(4, result.size());
  EXPECT_EQ("aa", result[0].first);
  EXPECT_EQ("a", result[1].first);
  EXPECT_EQ("bb", result[2].first);
  EXPECT_EQ("b", result[3].first);

  result = model.Encode("aaaba");
  EXPECT_EQ(3, result.size());
  EXPECT_EQ("aa", result[0].first);
  EXPECT_EQ("ab", result[1].first);
  EXPECT_EQ("a", result[2].first);

  // makes a broken utf-8
  const std::string broken_utf8 = std::string("あ").substr(0, 1);
  result = model.Encode(broken_utf8);
  EXPECT_EQ(1, result.size());
  EXPECT_EQ(broken_utf8, result[0].first);
}

TEST(BPEModelTest, NotSupportedTest) {
  ModelProto model_proto = MakeBaseModelProto();
  const Model model(model_proto);
  EXPECT_EQ(NBestEncodeResult(), model.NBestEncode("test", 10));
}

TEST(BPEModelTest, EncodeWithUnusedTest) {
  ModelProto model_proto = MakeBaseModelProto();

  AddPiece(&model_proto, "abcd", 10.0);  // 3
  AddPiece(&model_proto, "abc", 5.0);    // 4
  AddPiece(&model_proto, "ab", 2.0);     // 5
  AddPiece(&model_proto, "cd", 1.0);     // 6
  AddPiece(&model_proto, "a", 0.0);      // 7
  AddPiece(&model_proto, "b", 0.0);      // 8
  AddPiece(&model_proto, "c", 0.0);      // 9
  AddPiece(&model_proto, "d", 0.0);      // 10

  // No unused.
  {
    const Model model(model_proto);
    const auto result = model.Encode("abcd");
    EXPECT_EQ(1, result.size());
    EXPECT_EQ("abcd", result[0].first);
  }

  {
    model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED);
    const Model model(model_proto);
    const auto result = model.Encode("abcd");
    EXPECT_EQ(2, result.size());
    EXPECT_EQ("abc", result[0].first);
    EXPECT_EQ("d", result[1].first);
  }

  {
    // The parent rule "abc" is still alive even if the child "ab" is unused.
    model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED);
    model_proto.mutable_pieces(5)->set_type(ModelProto::SentencePiece::UNUSED);
    const Model model(model_proto);
    const auto result = model.Encode("abcd");
    EXPECT_EQ(2, result.size());
    EXPECT_EQ("abc", result[0].first);
    EXPECT_EQ("d", result[1].first);
  }

  {
    // This is a tricky case. Even though "cd" is alive, it is not used, as
    // it is not merged during the segmentation step.
    // Segmentation: a|b|c|d => ab|c|d => abc|d => abcd
    // Resegmentation: abcd => abc|d => ab|c|d. ("abcd", "abc" are unused)
    model_proto.mutable_pieces(3)->set_type(ModelProto::SentencePiece::UNUSED);
    model_proto.mutable_pieces(4)->set_type(ModelProto::SentencePiece::UNUSED);
    model_proto.mutable_pieces(5)->set_type(ModelProto::SentencePiece::NORMAL);
    const Model model(model_proto);
    const auto result = model.Encode("abcd");
    EXPECT_EQ(3, result.size());
    EXPECT_EQ("ab", result[0].first);
    EXPECT_EQ("c", result[1].first);
    EXPECT_EQ("d", result[2].first);
  }
}

TEST(SampleModelTest, EncodeTest) {
  ModelProto model_proto = MakeBaseModelProto();

  AddPiece(&model_proto, "ab", 0.0);
  AddPiece(&model_proto, "cd", -0.1);
  AddPiece(&model_proto, "abc", -0.2);
  AddPiece(&model_proto, "abcd", -0.3);

  // No regularization
  {
    const Model model(model_proto);
    const auto result = model.Encode("abcd");
    EXPECT_EQ(1, result.size());
    EXPECT_EQ("abcd", result[0].first);
  }

  {
    auto get_tokens = [](const EncodeResult &result) {
      std::string out;
      for (const auto &r : result) {
        if (!out.empty()) out += ' ';
        out += std::string(r.first);
      }
      return out;
    };

    const Model model(model_proto);
    const std::vector<double> kAlpha = {0.0, 0.1, 0.5, 0.7, 0.9};
    for (const auto alpha : kAlpha) {
      constexpr int kTrial = 100000;
      std::map<std::string, int> freq;
      for (int n = 0; n < kTrial; ++n)
        freq[get_tokens(
            model.SampleEncode("abcd", static_cast<float>(alpha)))]++;
      int num = 0;
      if (alpha == 0.0)
        EXPECT_EQ(1, freq.size());
      else
        EXPECT_GT(freq.size(), 1);
      for (const auto &it : freq) num += it.second;
      EXPECT_EQ(num, kTrial);
    }
  }
}

}  // namespace
}  // namespace bpe
}  // namespace sentencepiece
@ -1,323 +0,0 @@
|
|||
// Copyright 2016 Google Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.!
|
||||
|
||||
#include "bpe_model_trainer.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "pretokenizer_for_training.h"
|
||||
#include "third_party/absl/container/flat_hash_set.h"
|
||||
#include "third_party/absl/strings/str_join.h"
|
||||
#include "third_party/absl/strings/str_replace.h"
|
||||
#include "util.h"
|
||||
|
||||
namespace sentencepiece {
|
||||
namespace bpe {
|
||||
|
||||
std::string Trainer::Symbol::ToString() const {
|
||||
return string_util::UnicodeTextToUTF8(chars);
|
||||
}
|
||||
|
||||
Trainer::Symbol *Trainer::GetCharSymbol(char32 c) {
|
||||
const uint64 freq = port::FindWithDefault(required_chars_, c, 1);
|
||||
CHECK_GT(freq, 0);
|
||||
const auto it = symbols_cache_.find(c);
|
||||
if (it != symbols_cache_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
Symbol *s = new Symbol;
|
||||
allocated_.push_back(s);
|
||||
s->is_unk = (kUNKChar == c);
|
||||
s->fp = c;
|
||||
s->chars.push_back(c);
|
||||
s->freq = freq;
|
||||
port::InsertOrDie(&symbols_cache_, s->fp, s);
|
||||
return s;
|
||||
}
|
||||
|
||||
Trainer::Symbol *Trainer::GetPairSymbol(const Symbol *left,
|
||||
const Symbol *right) {
|
||||
if (left == nullptr || right == nullptr || left->is_unk || right->is_unk) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const uint64 fp = port::FingerprintCat(left->fp, right->fp);
|
||||
const auto it = symbols_cache_.find(fp);
|
||||
if (it != symbols_cache_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
|
||||
CHECK(!left->chars.empty());
|
||||
CHECK(!right->chars.empty());
|
||||
string_util::UnicodeText ut;
|
||||
for (const char32 c : left->chars) ut.push_back(c);
|
||||
for (const char32 c : right->chars) ut.push_back(c);
|
||||
|
||||
// Do not make an invalid piece.
|
||||
if (!IsValidSentencePiece(ut)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Symbol *s = new Symbol;
|
||||
allocated_.push_back(s);
|
||||
s->fp = fp;
|
||||
s->left = left;
|
||||
s->right = right;
|
||||
s->chars = ut;
|
||||
port::InsertOrDie(&symbols_cache_, s->fp, s);
|
||||
return s;
|
||||
}
|
||||
|
||||
void Trainer::ComputeFreq(Symbol *symbol) const {
|
||||
if (symbol->freq > 0) { // if freq == 0, re-computation is required.
|
||||
return;
|
||||
}
|
||||
CHECK_EQ(0, symbol->freq);
|
||||
for (auto it = symbol->positions.begin(); it != symbol->positions.end();) {
|
||||
const Position pos = DecodePos(*it);
|
||||
// symbols_[sid][left] and symbols_[sid]right] must store
|
||||
// the same symbols in symbol->left and symbols->right.
|
||||
if (symbol->left != symbols_[pos.sid][pos.left] ||
|
||||
symbol->right != symbols_[pos.sid][pos.right]) {
|
||||
it = symbol->positions.erase(it);
|
||||
} else {
|
||||
symbol->freq += sentences_[pos.sid].second;
|
||||
++it;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int Trainer::GetNextIndex(int sid, int index) const {
|
||||
for (size_t i = index + 1; i < symbols_[sid].size(); ++i) {
|
||||
if (symbols_[sid][i] == nullptr) continue;
|
||||
return i;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int Trainer::GetPrevIndex(int sid, int index) const {
|
||||
for (int i = index - 1; i >= 0; --i) {
|
||||
if (symbols_[sid][i] == nullptr) continue;
|
||||
return i;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
void Trainer::AddNewPair(int sid, int left, int right) {
|
||||
if (left == -1 || right == -1) return;
|
||||
auto *symbol = GetPairSymbol(symbols_[sid][left], symbols_[sid][right]);
|
||||
if (symbol != nullptr) {
|
||||
active_symbols_.insert(symbol);
|
||||
symbol->positions.insert(EncodePos(sid, left, right));
|
||||
}
|
||||
}
|
||||
|
||||
void Trainer::ResetFreq(int sid, int left, int right, const Symbol *best) {
|
||||
if (left == -1 || right == -1) return;
|
||||
auto *symbol = GetPairSymbol(symbols_[sid][left], symbols_[sid][right]);
|
||||
if (symbol != nullptr && symbol != best) {
|
||||
symbol->freq = 0;
|
||||
}
|
||||
}
|
||||
|
||||
void Trainer::UpdateActiveSymbols() {
|
||||
std::vector<Symbol *> symbols;
|
||||
for (auto &it : symbols_cache_) {
|
||||
Symbol *symbol = it.second;
|
||||
if (symbol->IsBigram()) {
|
||||
ComputeFreq(symbol);
|
||||
symbols.push_back(symbol);
|
||||
}
|
||||
}
|
||||
|
||||
// At least kMinActiveSymbolsSize symbols must be in |active_symbols_|.
|
||||
constexpr int kMinActiveSymbolsSize = 1000;
|
||||
|
||||
// Keeps top 5% frequent symbols.
|
||||
constexpr float kTopFrequentRatio = 0.05;
|
||||
const int size =
|
||||
std::min<int>(std::max<int>(kMinActiveSymbolsSize,
|
||||
symbols_cache_.size() * kTopFrequentRatio),
|
||||
symbols.size());
|
||||
|
||||
std::partial_sort(symbols.begin(), symbols.begin() + size, symbols.end(),
|
||||
[](Symbol *s1, Symbol *s2) { return s1->freq > s2->freq; });
|
||||
LOG(INFO) << "Updating active symbols. max_freq=" << symbols[0]->freq
|
||||
<< " min_freq=" << symbols[size - 1]->freq;
|
||||
|
||||
active_symbols_.clear();
|
||||
active_symbols_.insert(symbols.begin(), symbols.begin() + size);
|
||||
}

util::Status Trainer::Train() {
  RETURN_IF_ERROR(status());

  CHECK_OR_RETURN(normalizer_spec_.escape_whitespaces());
  CHECK_EQ_OR_RETURN(TrainerSpec::BPE, trainer_spec_.model_type());

  symbols_.clear();
  allocated_.clear();
  symbols_cache_.clear();
  active_symbols_.clear();

  // Loads all sentences.
  RETURN_IF_ERROR(LoadSentences());

  if (trainer_spec_.split_by_whitespace()) {
    SplitSentencesByWhitespace();
  }

  // The pretokenizer is applied only at training time and is used as a
  // constraint on piece extraction.
  const auto *pretokenizer = SentencePieceTrainer::GetPretokenizerForTraining();

  if (pretokenizer || !trainer_spec_.pretokenization_delimiter().empty()) {
    absl::string_view delimiter = trainer_spec_.pretokenization_delimiter();
    LOG(INFO) << "Preprocessing with pretokenizer...";
    for (auto &w : sentences_) {
      if (pretokenizer) {
        w.first = absl::StrJoin(pretokenizer->PreTokenize(w.first),
                                TrainerInterface::kUPPBoundaryStr);
      } else if (!delimiter.empty()) {
        w.first = absl::StrReplaceAll(
            w.first, {{delimiter, TrainerInterface::kUPPBoundaryStr}});
      }
    }
  }

  // Initializes symbols_. symbols_[sid][i] stores a unary symbol.
  symbols_.resize(sentences_.size());
  for (size_t i = 0; i < sentences_.size(); ++i) {
    for (const char32 c : string_util::UTF8ToUnicodeText(sentences_[i].first)) {
      symbols_[i].push_back(GetCharSymbol(c));
    }
  }

  // Makes all bigram symbols.
  for (size_t sid = 0; sid < symbols_.size(); ++sid) {
    for (size_t i = 1; i < symbols_[sid].size(); ++i) {
      AddNewPair(sid, i - 1, i);
    }
  }

  const int vocab_size =
      trainer_spec_.vocab_size() - meta_pieces_.size() - required_chars_.size();
  CHECK_GE_OR_RETURN(vocab_size, 0);

  // We may see duplicated pieces that are extracted along different paths.
  // In the real segmentation phase, we can consider them as one symbol.
  // e.g., "aaa" => "aa" + "a" or "a" + "aa".
  absl::flat_hash_set<std::string> dup;

  // Main loop.
  CHECK_OR_RETURN(final_pieces_.empty());
  while (final_pieces_.size() < static_cast<size_t>(vocab_size)) {
    constexpr int kUpdateActiveSymbolsInterval = 100;
    if (final_pieces_.size() % kUpdateActiveSymbolsInterval == 0) {
      UpdateActiveSymbols();
    }

    // Scans the active symbols and finds the best_symbol with the highest
    // frequency.
    Symbol *best_symbol = nullptr;
    for (auto &it : active_symbols_) {
      Symbol *symbol = it;
      ComputeFreq(symbol);
      // If the frequency is the same, takes the shorter symbol.
      // If the length is also the same, uses lexicographical comparison.
      if (best_symbol == nullptr ||
          (symbol->freq > best_symbol->freq ||
           (symbol->freq == best_symbol->freq &&
            (symbol->chars.size() < best_symbol->chars.size() ||
             (symbol->chars.size() == best_symbol->chars.size() &&
              symbol->ToString() < best_symbol->ToString()))))) {
        best_symbol = symbol;
      }
    }
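    // Editor's note (illustrative, not part of the original source): given
    // candidates ("ab", freq=10), ("cd", freq=10) and ("abc", freq=12),
    // "abc" wins on frequency alone; between the equal-frequency "ab" and
    // "cd", both of length 2, the lexicographically smaller "ab" is chosen.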

    if (best_symbol == nullptr) {
      LOG(WARNING) << "No valid symbol found";
      break;
    }

    if (!dup.insert(best_symbol->ToString()).second) {
      // Removes best_symbol so it is not selected again.
      symbols_cache_.erase(best_symbol->fp);
      active_symbols_.erase(best_symbol);
      continue;
    }

    // Stores the best_symbol in the final output.
    final_pieces_.emplace_back(best_symbol->ToString(),
                               -static_cast<float>(final_pieces_.size()));

    if (final_pieces_.size() % 20 == 0) {
      LOG(INFO) << "Added: freq=" << best_symbol->freq
                << " size=" << final_pieces_.size()
                << " all=" << symbols_cache_.size()
                << " active=" << active_symbols_.size()
                << " piece=" << best_symbol->ToString();
    }

    // Adds new bigrams which are created after the symbol replacement.
    // We do not need to scan all characters; scanning the neighbors of
    // best_symbol is sufficient.
    for (const uint64 &encoded_pos : best_symbol->positions) {
      const Position pos = DecodePos(encoded_pos);

      if (symbols_[pos.sid][pos.left] == nullptr) {
        // The left index might be nullptr (set in the previous iteration)
        // when left_symbol == right_symbol.
        continue;
      }
      CHECK_OR_RETURN(symbols_[pos.sid][pos.right]);

      // We have three bigrams [prev, left], [left, right], [right, next],
      // which are affected by this symbol replacement.
      const int next = GetNextIndex(pos.sid, pos.right);
      const int prev = GetPrevIndex(pos.sid, pos.left);

      // Resets the frequencies of bigrams [prev, left] and [right, next].
      ResetFreq(pos.sid, prev, pos.left, best_symbol);
      ResetFreq(pos.sid, pos.right, next, best_symbol);

      // Merges the two symbols.
      symbols_[pos.sid][pos.left] = best_symbol;
      symbols_[pos.sid][pos.right] = nullptr;

      // Makes new symbol bigrams [prev, left] and [left, next].
      AddNewPair(pos.sid, prev, pos.left);
      AddNewPair(pos.sid, pos.left, next);
    }

    // Removes best_symbol so it is not selected again.
    symbols_cache_.erase(best_symbol->fp);
    active_symbols_.erase(best_symbol);
  }  // end of main loop

  // Adds required_chars_.
  for (const auto &w : Sorted(required_chars_)) {
    const Symbol *symbol = GetCharSymbol(w.first);
    final_pieces_.emplace_back(symbol->ToString(),
                               -static_cast<float>(final_pieces_.size()));
  }

  port::STLDeleteElements(&allocated_);

  return Save();
}
}  // namespace bpe
}  // namespace sentencepiece
@@ -1,130 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef BPE_MODEL_TRAINER_H_
#define BPE_MODEL_TRAINER_H_

#include <cstdint>
#include <limits>
#include <string>
#include <vector>

#include "sentencepiece_model.pb.h"
#include "third_party/absl/container/btree_set.h"
#include "third_party/absl/container/flat_hash_map.h"
#include "trainer_interface.h"

namespace sentencepiece {
namespace bpe {

// Trainer class for the BPE model.
class Trainer : public TrainerInterface {
 public:
  Trainer(const TrainerSpec &trainer_spec,
          const NormalizerSpec &normalizer_spec,
          const NormalizerSpec &denormalizer_spec)
      : TrainerInterface::TrainerInterface(trainer_spec, normalizer_spec,
                                           denormalizer_spec) {}

  util::Status Train() override;

 private:
  // Symbol represents a character or a symbol bigram.
  struct Symbol {
    const Symbol *left;              // left symbol in bigram
    const Symbol *right;             // right symbol in bigram
    string_util::UnicodeText chars;  // all flattened character sequence
    bool is_unk;                     // true if this symbol is unknown.
    uint64_t fp;                     // fingerprint of this symbol.
    uint64_t freq;                   // frequency of this symbol.

    // Position list. Uses a set so that we can keep the order of occurrence.
    // See EncodePos/DecodePos.
    absl::btree_set<uint64_t> positions;

    bool IsBigram() const { return left != nullptr && right != nullptr; }
    std::string ToString() const;
    Symbol() : left(nullptr), right(nullptr), is_unk(false), fp(0), freq(0) {}
  };

  struct Position {
    int sid;    // sentence id
    int left;   // left symbol index
    int right;  // right symbol index
  };

  // Encodes sid, left and right bigram index into a uint64_t.
  // The encoded value keeps the order of sid, left and right.
  static uint64_t EncodePos(int sid, int l, int r) {
    CHECK_GE(l, 0);
    CHECK_GE(r, 0);
    CHECK_LE(l, std::numeric_limits<uint16_t>::max());
    CHECK_LE(r, std::numeric_limits<uint16_t>::max());
    const uint64_t n = (static_cast<uint64_t>(sid) << 32) |
                       (static_cast<uint64_t>(l) << 16) | r;
    return n;
  }

  // Decodes sid, left and right bigram index from a uint64_t.
  static Position DecodePos(uint64_t n) {
    Position p;
    p.sid = n >> 32;
    p.left = (n >> 16) & 0xffff;
    p.right = n & 0xffff;
    return p;
  }
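  // Editor's note (illustrative, not part of the original source):
  // EncodePos(2, 3, 4) yields 0x0000000200030004, with sid in the high 32
  // bits, left in bits 16..31 and right in bits 0..15; DecodePos recovers
  // {sid=2, left=3, right=4}. Because the packing preserves field order,
  // encoded positions sort in (sid, left, right) order inside |positions|.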

  // Gets a unary (character) symbol from the char code |c|.
  // The return value is cached.
  Symbol *GetCharSymbol(char32 c);

  // Gets a symbol pair from left/right symbols. The return value is cached.
  Symbol *GetPairSymbol(const Symbol *left, const Symbol *right);

  // Computes the frequency of |symbol| and updates the symbol->freq field.
  void ComputeFreq(Symbol *symbol) const;

  // Returns the valid index after symbols_[sid][index].
  int GetNextIndex(int sid, int index) const;

  // Returns the valid index before symbols_[sid][index].
  int GetPrevIndex(int sid, int index) const;

  // Makes a new bigram from [symbols_[sid][left], symbols_[sid][right]] and
  // adds it to symbols_cache_ and active_symbols_.
  void AddNewPair(int sid, int left, int right);

  // Resets the frequency of the bigram
  // [symbols_[sid][left] symbols_[sid][right]], if this bigram is not |best|.
  void ResetFreq(int sid, int left, int right, const Symbol *best);

  // Updates |active_symbols_| by copying the top 5% most frequent symbols in
  // symbols_cache_.
  void UpdateActiveSymbols();

  // All unique symbols. The key is a fingerprint of the Symbol.
  absl::flat_hash_map<uint64_t, Symbol *> symbols_cache_;

  // Set of symbols from which we find the best symbol in each iteration.
  absl::btree_set<Symbol *> active_symbols_;

  // Stores symbols allocated on the heap so that we can delete them at once.
  std::vector<Symbol *> allocated_;

  // Sentences. symbols_[sid][index] stores a symbol in sentence_[sid][index].
  std::vector<std::vector<Symbol *>> symbols_;
};
}  // namespace bpe
}  // namespace sentencepiece
#endif  // BPE_MODEL_TRAINER_H_
@@ -1,139 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include <vector>

#include "bpe_model_trainer.h"
#include "filesystem.h"
#include "sentencepiece_processor.h"
#include "sentencepiece_trainer.h"
#include "testharness.h"
#include "third_party/absl/strings/str_cat.h"
#include "third_party/absl/strings/str_join.h"
#include "util.h"

namespace sentencepiece {
namespace bpe {
namespace {

// Space symbol
#define WS "\xe2\x96\x81"

std::string RunTrainer(
    const std::vector<std::string> &input, int size,
    const std::vector<std::string> &user_defined_symbols = {}) {
  const std::string input_file =
      util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "input");
  const std::string model_prefix =
      util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "model");
  {
    auto output = filesystem::NewWritableFile(input_file);
    for (const auto &line : input) {
      output->WriteLine(line);
    }
  }

  TrainerSpec trainer_spec;
  trainer_spec.set_model_type(TrainerSpec::BPE);
  trainer_spec.add_input(input_file);
  trainer_spec.set_vocab_size(size - 3);  // remove <unk>, <s>, </s>
  trainer_spec.set_model_prefix(model_prefix);

  NormalizerSpec normalizer_spec;
  normalizer_spec.set_name("identity");
  normalizer_spec.set_add_dummy_prefix(false);

  NormalizerSpec denormalizer_spec;

  for (const auto &w : user_defined_symbols) {
    trainer_spec.add_user_defined_symbols(w);
  }

  Trainer trainer(trainer_spec, normalizer_spec, denormalizer_spec);
  EXPECT_TRUE(trainer.Train().ok());

  SentencePieceProcessor processor;
  EXPECT_TRUE(processor.Load(model_prefix + ".model").ok());

  const auto &model = processor.model_proto();
  std::vector<std::string> pieces;

  // remove <unk>, <s>, </s>
  for (int i = 3; i < model.pieces_size(); ++i) {
    pieces.emplace_back(model.pieces(i).piece());
  }

  return absl::StrJoin(pieces, " ");
}

TEST(BPETrainerTest, BasicTest) {
  EXPECT_EQ("ab ra abra ad cad abracad abracadabra ac br a b r c d",
            RunTrainer({"abracadabra"}, 20));
  EXPECT_EQ("ap le app apple en in ine pen p e a l n i",
            RunTrainer({"pen", "pineapple", "apple"}, 20));
  EXPECT_EQ("he ll llo hello hellohe el lo oh hel ohe e h l o",
            RunTrainer({"hellohe"}, 20));
  EXPECT_EQ("app le en in ine pen pine ne pe e l n p i",
            RunTrainer({"pen", "pineapple", "apple"}, 20, {"app"}));
}

static constexpr char kTestInputData[] = "wagahaiwa_nekodearu.txt";

TEST(BPETrainerTest, EndToEndTest) {
  const std::string input =
      util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestInputData);

  ASSERT_TRUE(
      SentencePieceTrainer::Train(
          absl::StrCat(
              "--model_prefix=",
              util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "tmp_model"),
              " --input=", input,
              " --vocab_size=8000 --normalization_rule_name=identity"
              " --model_type=bpe --control_symbols=<ctrl> "
              "--max_sentence_length=2048"))
          .ok());

  SentencePieceProcessor sp;
  ASSERT_TRUE(sp.Load(std::string(util::JoinPath(
                          absl::GetFlag(FLAGS_test_tmpdir), "tmp_model.model")))
                  .ok());
  EXPECT_EQ(8000, sp.GetPieceSize());

  const int cid = sp.PieceToId("<ctrl>");
  EXPECT_TRUE(sp.IsControl(cid));

  std::vector<std::string> tok;
  ASSERT_TRUE(sp.Encode("", &tok).ok());
  ASSERT_TRUE(tok.empty());

  EXPECT_TRUE(sp.Encode("吾輩《わがはい》は猫である。名前はまだ無い。"
                        "どこで生れたかとんと見当《けんとう》がつかぬ。"
                        "何でも薄暗いじめじめした所でニャーニャー泣いていた事だ"
                        "けは記憶している"
                        "。",
                        &tok)
                  .ok());
  EXPECT_EQ(WS
            " 吾輩 《 わが はい 》 は猫 である 。 名前 はまだ 無い 。 "
            "どこで 生 れた か とん と見 当 《 けんとう 》 が つかぬ 。 "
            "何でも 薄 暗 いじ め じ め した 所で ニャー ニャー 泣 いていた "
            "事 だけは 記憶 している 。",
            absl::StrJoin(tok, " "));
}

}  // namespace
}  // namespace bpe
}  // namespace sentencepiece
@@ -1,599 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "builder.h"

#include <algorithm>
#include <functional>
#include <utility>

#include "filesystem.h"
#include "third_party/absl/strings/str_join.h"
#include "third_party/absl/strings/str_replace.h"
#include "third_party/absl/strings/str_split.h"
#include "third_party/absl/strings/strip.h"

#ifdef ENABLE_NFKC_COMPILE
#include <unicode/errorcode.h>
#include <unicode/locid.h>
#include <unicode/normlzr.h>
#include <unicode/numfmt.h>
#include <unicode/rbnf.h>
#include <unicode/utypes.h>
#endif  // ENABLE_NFKC_COMPILE

#include <set>

#include "normalization_rule.h"
#include "normalizer.h"
#include "third_party/darts_clone/darts.h"
#include "util.h"

namespace sentencepiece {
namespace normalizer {
namespace {

constexpr int kMaxUnicode = 0x10FFFF;

static constexpr char kDefaultNormalizerName[] = "nfkc";

#ifndef ENABLE_NFKC_COMPILE
static constexpr char kCompileError[] =
    "NFKC compile is not enabled. Rebuild with ./configure "
    "--enable-nfkc-compile";
#endif

#ifdef ENABLE_NFKC_COMPILE
// Normalizes `input` with ICU's normalizer using `mode`.
Builder::Chars UnicodeNormalize(UNormalizationMode mode,
                                const Builder::Chars &input) {
  const std::string utf8 = string_util::UnicodeTextToUTF8(input);
  CHECK(!utf8.empty());

  icu::UnicodeString ustr = icu::UnicodeString::fromUTF8(utf8.c_str());

  UErrorCode status = U_ZERO_ERROR;
  icu::UnicodeString dst;
  icu::Normalizer::normalize(ustr, mode, 0, dst, status);
  CHECK(U_SUCCESS(status));
  std::string normalized;
  normalized.reserve(dst.length() * 3);
  dst.toUTF8String(normalized);
  return string_util::UTF8ToUnicodeText(normalized);
}

Builder::Chars ToNFKD(const Builder::Chars &input) {
  return UnicodeNormalize(UNORM_NFKD, input);
}

Builder::Chars ToNFKC(const Builder::Chars &input) {
  return UnicodeNormalize(UNORM_NFKC, input);
}

Builder::Chars ToNFC(const Builder::Chars &input) {
  return UnicodeNormalize(UNORM_NFC, input);
}

Builder::Chars ToNFD(const Builder::Chars &input) {
  return UnicodeNormalize(UNORM_NFD, input);
}

// Given an NFKD-normalized string, returns the set of all strings which are
// normalized into the same `nfkd`. `norm2orig` is the normalized to
// un-normalized character mapping.
std::vector<Builder::Chars> ExpandUnnormalized(
    const Builder::Chars &nfkd,
    const std::map<char32, std::set<char32>> &norm2orig) {
  CHECK(!nfkd.empty());
  std::vector<Builder::Chars> results;
  for (const auto c : port::FindOrDie(norm2orig, nfkd[0])) {
    results.push_back({c});
  }
  for (size_t i = 1; i < nfkd.size(); ++i) {
    const auto &orig = port::FindOrDie(norm2orig, nfkd[i]);
    std::vector<Builder::Chars> new_results;
    for (const auto &r : results) {
      for (const auto c : orig) {
        new_results.emplace_back(r);
        new_results.back().push_back(c);
      }
    }
    results = std::move(new_results);
  }
  CHECK_EQ(nfkd.size(), results[0].size());
  return results;
}
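
// Editor's note (illustrative, not part of the original source): if
// norm2orig maps 'a' => {'a', 'A'} and 'b' => {'b'}, then
// ExpandUnnormalized("ab", norm2orig) returns {"ab", "Ab"}: the cross
// product, position by position, of every original character that
// normalizes to the corresponding character of the NFKD string.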
#endif

// Normalizes `src` with `chars_map` and returns the normalized Chars.
// `max_len` specifies the maximum length of the keys in `chars_map`.
Builder::Chars Normalize(const Builder::CharsMap &chars_map,
                         const Builder::Chars &src, int max_len) {
  CHECK_GE(max_len, 1);
  Builder::Chars normalized;

  for (size_t i = 0; i < src.size();) {
    Builder::CharsMap::const_iterator it = chars_map.end();
    const size_t slice = std::min<size_t>(i + max_len, src.size());
    // Starts with the longest prefix.
    Builder::Chars key(src.begin() + i, src.begin() + slice);
    while (!key.empty()) {
      it = chars_map.find(key);
      if (it != chars_map.end()) {
        break;
      }
      key.pop_back();  // removes the last character.
    }

    // Consumes one character when no rule is found.
    if (it == chars_map.end()) {
      normalized.push_back(src[i]);
      ++i;
    } else {
      std::copy(it->second.begin(), it->second.end(),
                std::back_inserter(normalized));
      i += it->first.size();
    }
  }

  return normalized;
}
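
// Editor's note (illustrative, not part of the original source): with the
// rules {"ab" => "X", "a" => "y"} and max_len = 2, Normalize("aab") first
// tries the slice "aa" (no rule), falls back to "a" => "y", then matches
// "ab" => "X", producing "yX". The longest matching prefix always wins over
// shorter rules at each position.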
}  // namespace

// static
util::Status Builder::CompileCharsMap(const CharsMap &chars_map,
                                      std::string *output) {
  CHECK_OR_RETURN(output);
  CHECK_OR_RETURN(!chars_map.empty());

  LOG(INFO) << "Loading CharsMap of size=" << chars_map.size();

  // Aggregates the same target strings to save footprint.
  std::map<Chars, int> normalized2pos;
  for (const auto &p : chars_map) {
    normalized2pos[p.second] = 0;
  }

  std::string normalized;
  for (auto &p : normalized2pos) {
    p.second = normalized.size();  // stores the pointer (position).
    const std::string utf8_out = string_util::UnicodeTextToUTF8(p.first);
    CHECK_OR_RETURN(string_util::IsStructurallyValid(utf8_out));
    normalized += utf8_out;
    normalized += '\0';
  }

  std::vector<std::pair<std::string, int>> kv;  // key-value of the Trie.
  for (const auto &p : chars_map) {
    // The value of the Trie stores the pointer to the normalized string.
    const std::string utf8_in = string_util::UnicodeTextToUTF8(p.first);
    CHECK_OR_RETURN(!utf8_in.empty());
    CHECK_OR_RETURN(string_util::IsStructurallyValid(utf8_in));
    kv.emplace_back(utf8_in, port::FindOrDie(normalized2pos, p.second));
  }

  std::sort(kv.begin(), kv.end());
  std::vector<const char *> key(kv.size());
  std::vector<int> value(kv.size());
  for (size_t i = 0; i < kv.size(); ++i) {
    key[i] = kv[i].first.c_str();
    value[i] = kv[i].second;
  }

  Darts::DoubleArray trie;
  CHECK_EQ_OR_RETURN(0, trie.build(key.size(), const_cast<char **>(&key[0]),
                                   nullptr, &value[0]))
      << "cannot build double-array";

  int max_nodes_size = 0;
  std::vector<Darts::DoubleArray::result_pair_type> results(
      2 * Normalizer::kMaxTrieResultsSize);
  for (const char *str : key) {
    const int num_nodes = trie.commonPrefixSearch(str, results.data(),
                                                  results.size(), strlen(str));
    max_nodes_size = std::max(num_nodes, max_nodes_size);
  }
  CHECK_LT_OR_RETURN(max_nodes_size, Normalizer::kMaxTrieResultsSize)
      << "This charsmap contains too many shared prefixes. "
      << "The number of shared prefixes must be less than "
      << Normalizer::kMaxTrieResultsSize;

  absl::string_view trie_blob(static_cast<const char *>(trie.array()),
                              trie.size() * trie.unit_size());
  *output = Normalizer::EncodePrecompiledCharsMap(trie_blob, normalized);

  LOG(INFO) << "Generated normalizer blob. size=" << output->size();

  return util::OkStatus();
}
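
// Editor's note (illustrative, not part of the original source): a rule set
// {"A" => "a", "B" => "a"} compiles into a double array with keys "A" and
// "B" whose values are both 0, the byte offset of "a\0" inside the shared,
// de-duplicated `normalized` pool built above. Sharing offsets is what
// keeps the blob compact when many rules map to the same target.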

// static
util::Status Builder::DecompileCharsMap(absl::string_view blob,
                                        Builder::CharsMap *chars_map) {
  CHECK_OR_RETURN(chars_map);
  chars_map->clear();

  absl::string_view trie_blob, normalized;
  std::string buf;
  RETURN_IF_ERROR(Normalizer::DecodePrecompiledCharsMap(blob, &trie_blob,
                                                        &normalized, &buf));

  Darts::DoubleArray trie;
  trie.set_array(const_cast<char *>(trie_blob.data()),
                 trie_blob.size() / trie.unit_size());

  std::string key;
  std::function<void(size_t, size_t)> traverse;

  // Given a Trie node at `node_pos` and the key position at `key_pos`,
  // expands the children nodes from `node_pos`.
  // When leaf nodes are found, stores them into `chars_map`.
  traverse = [&traverse, &key, &trie, &normalized, &chars_map](
                 size_t node_pos, size_t key_pos) -> void {
    for (int c = 0; c <= 255; ++c) {
      key.push_back(static_cast<char>(c));
      size_t copied_node_pos = node_pos;
      size_t copied_key_pos = key_pos;
      // Note: `copied_(node|key)_pos` are non-const references.
      // They store the new positions after node traversal.
      const Darts::DoubleArray::result_type result = trie.traverse(
          key.data(), copied_node_pos, copied_key_pos, key.size());
      if (result >= -1) {   // node exists.
        if (result >= 0) {  // has a value after the transition.
          const absl::string_view value = normalized.data() + result;
          Chars key_chars, value_chars;
          for (const auto c : string_util::UTF8ToUnicodeText(key))
            key_chars.push_back(c);
          for (const auto c : string_util::UTF8ToUnicodeText(value))
            value_chars.push_back(c);
          (*chars_map)[key_chars] = value_chars;
        }
        // Recursively traverses.
        traverse(copied_node_pos, copied_key_pos);
      }
      key.pop_back();
    }
  };

  traverse(0, 0);

  return util::OkStatus();
}

// static
util::Status Builder::GetPrecompiledCharsMap(absl::string_view name,
                                             std::string *output) {
  CHECK_OR_RETURN(output);

  if (name == "identity") {
    output->clear();
    return util::OkStatus();
  }

  std::string result;
  for (size_t i = 0; i < kNormalizationRules_size; ++i) {
    const auto *blob = &kNormalizationRules_blob[i];
    if (blob->name == name) {
      output->assign(blob->data, blob->size);
      return util::OkStatus();
    }
  }
  return util::StatusBuilder(util::StatusCode::kNotFound, GTL_LOC)
         << "No precompiled charsmap is found: " << name;
}

// static
util::Status Builder::BuildNFKCMap(CharsMap *chars_map) {
#ifdef ENABLE_NFKC_COMPILE
  LOG(INFO) << "Running BuildNFKCMap";

  // Set of fully NFKD-decomposed characters.
  std::set<Builder::Chars> nfkd_decomposed;

  // Fully normalized one character to unnormalized one character map.
  std::map<char32, std::set<char32>> norm2orig;

  Builder::CharsMap nfkc_map;  // The final NFKC mapping.

  constexpr int kMaxUnicode = 0x10FFFF;
  for (char32 cp = 1; cp <= kMaxUnicode; ++cp) {
    if (!U_IS_UNICODE_CHAR(cp)) {
      continue;
    }
    // Aggregates a single character to fully NFKC-normalized characters.
    const auto nfkc = ToNFKC({cp});
    if (nfkc.size() >= 2 || (nfkc.size() == 1 && nfkc[0] != cp)) {
      nfkc_map[{cp}] = nfkc;
    }
    const auto nfkd = ToNFKD({cp});
    if (nfkd.size() == 1) {
      // Aggregates the reverse mapping from normalized to unnormalized
      // characters.
      norm2orig[nfkd[0]].insert(cp);
    } else {
      // One character is decomposed into multiple characters.
      nfkd_decomposed.insert(nfkd);
    }
  }

  for (const auto &nfkd : nfkd_decomposed) {
    const auto nfkc = ToNFC(nfkd);
    // This case is already covered by the single-character to NFKC mapping.
    if (nfkc == nfkd) {
      continue;
    }
    // Expands all possible sequences which are normalized into the same
    // `nfkd`.
    for (const auto &nfkd_orig : ExpandUnnormalized(nfkd, norm2orig)) {
      if (nfkd_orig != nfkc) {
        nfkc_map[nfkd_orig] = nfkc;
      }
    }
  }

  RETURN_IF_ERROR(RemoveRedundantMap(&nfkc_map));
  *chars_map = std::move(nfkc_map);

#else
  LOG(ERROR) << kCompileError;
#endif

  return util::OkStatus();
}

util::Status Builder::BuildNmtNFKCMap(CharsMap *chars_map) {
#ifdef ENABLE_NFKC_COMPILE
  LOG(INFO) << "Running BuildNmtNFKCMap";

  CharsMap nfkc_map;
  RETURN_IF_ERROR(Builder::BuildNFKCMap(&nfkc_map));

  // Other code points considered as whitespace.
  nfkc_map[{0x0009}] = {0x20};  // TAB
  nfkc_map[{0x000A}] = {0x20};  // LINE FEED
  nfkc_map[{0x000C}] = {0x20};  // FORM FEED
  nfkc_map[{0x000D}] = {0x20};  // CARRIAGE RETURN
  nfkc_map[{0x1680}] = {0x20};  // OGHAM SPACE MARK
  nfkc_map[{0x200B}] = {0x20};  // ZERO WIDTH SPACE
  nfkc_map[{0x200E}] = {0x20};  // LEFT-TO-RIGHT MARK
  nfkc_map[{0x200F}] = {0x20};  // RIGHT-TO-LEFT MARK
  nfkc_map[{0x2028}] = {0x20};  // LINE SEPARATOR
  nfkc_map[{0x2029}] = {0x20};  // PARAGRAPH SEPARATOR
  nfkc_map[{0x2581}] = {0x20};  // LOWER ONE EIGHTH BLOCK
  nfkc_map[{0xFEFF}] = {0x20};  // ZERO WIDTH NO-BREAK SPACE
  nfkc_map[{0xFFFD}] = {0x20};  // REPLACEMENT CHARACTER
  nfkc_map[{0x200C}] = {0x20};  // ZERO WIDTH NON-JOINER
  // nfkc_map[{0x200D}] = {0x20};  // ZERO WIDTH JOINER

  // ASCII control characters
  nfkc_map[{0x0001}] = {};
  nfkc_map[{0x0002}] = {};
  nfkc_map[{0x0003}] = {};
  nfkc_map[{0x0004}] = {};
  nfkc_map[{0x0005}] = {};
  nfkc_map[{0x0006}] = {};
  nfkc_map[{0x0007}] = {};
  nfkc_map[{0x0008}] = {};
  nfkc_map[{0x000B}] = {};
  nfkc_map[{0x000E}] = {};
  nfkc_map[{0x000F}] = {};
  nfkc_map[{0x0010}] = {};
  nfkc_map[{0x0011}] = {};
  nfkc_map[{0x0012}] = {};
  nfkc_map[{0x0013}] = {};
  nfkc_map[{0x0014}] = {};
  nfkc_map[{0x0015}] = {};
  nfkc_map[{0x0016}] = {};
  nfkc_map[{0x0017}] = {};
  nfkc_map[{0x0018}] = {};
  nfkc_map[{0x0019}] = {};
  nfkc_map[{0x001A}] = {};
  nfkc_map[{0x001B}] = {};
  nfkc_map[{0x001C}] = {};
  nfkc_map[{0x001D}] = {};
  nfkc_map[{0x001E}] = {};
  nfkc_map[{0x001F}] = {};

  // <control-007F>..<control-009F>
  nfkc_map[{0x007F}] = {};
  nfkc_map[{0x008F}] = {};
  nfkc_map[{0x009F}] = {};

  // Do not normalize FULL_WIDTH TILDE, since FULL_WIDTH TILDE
  // and HALF_WIDTH TILDE are used differently in Japanese.
  nfkc_map.erase({0xFF5E});

  RETURN_IF_ERROR(RemoveRedundantMap(&nfkc_map));

  *chars_map = std::move(nfkc_map);

#else
  LOG(ERROR) << kCompileError;
#endif

  return util::OkStatus();
}

// static
util::Status Builder::MergeUnicodeCaseFoldMap(Builder::CharsMap *chars_map) {
#ifdef ENABLE_NFKC_COMPILE
  for (auto &c : *chars_map) {
    std::vector<char32> trg;
    for (char32 c : c.second) trg.push_back(u_foldCase(c, U_FOLD_CASE_DEFAULT));
    c.second = trg;
  }

  constexpr int kMaxUnicode = 0x10FFFF;
  for (char32 cp = 1; cp <= kMaxUnicode; ++cp) {
    if (!U_IS_UNICODE_CHAR(cp)) {
      continue;
    }
    if (chars_map->find({cp}) != chars_map->end()) continue;
    const char32 trg = u_foldCase(cp, U_FOLD_CASE_DEFAULT);
    if (trg != cp) (*chars_map)[{cp}] = {trg};
  }

  RETURN_IF_ERROR(RemoveRedundantMap(chars_map));
#endif

  return util::OkStatus();
}

// static
util::Status Builder::BuildNFKC_CFMap(CharsMap *chars_map) {
#ifdef ENABLE_NFKC_COMPILE
  CharsMap nfkc_map;
  RETURN_IF_ERROR(Builder::BuildNFKCMap(&nfkc_map));
  RETURN_IF_ERROR(Builder::MergeUnicodeCaseFoldMap(&nfkc_map));
  *chars_map = std::move(nfkc_map);
#else
  LOG(ERROR) << kCompileError;
#endif

  return util::OkStatus();
}

// static
util::Status Builder::BuildNmtNFKC_CFMap(CharsMap *chars_map) {
#ifdef ENABLE_NFKC_COMPILE
  CharsMap nfkc_map;
  RETURN_IF_ERROR(Builder::BuildNmtNFKCMap(&nfkc_map));
  RETURN_IF_ERROR(Builder::MergeUnicodeCaseFoldMap(&nfkc_map));
  *chars_map = std::move(nfkc_map);
#else
  LOG(ERROR) << kCompileError;
#endif

  return util::OkStatus();
}

// static
util::Status Builder::BuildNFKDMap(CharsMap *chars_map) {
#ifdef ENABLE_NFKC_COMPILE
  constexpr int kMaxUnicode = 0x10FFFF;
  for (char32 cp = 1; cp <= kMaxUnicode; ++cp) {
    if (!U_IS_UNICODE_CHAR(cp)) {
      continue;
    }
    const auto nfkd = ToNFKD({cp});
    if (nfkd.size() >= 2 || (nfkd.size() == 1 && nfkd[0] != cp)) {
      (*chars_map)[{cp}] = nfkd;
    }
  }
#else
  LOG(ERROR) << kCompileError;
#endif
  return util::OkStatus();
}

// static
util::Status Builder::LoadCharsMap(absl::string_view filename,
                                   CharsMap *chars_map) {
  LOG(INFO) << "Loading mapping file: " << filename.data();
  CHECK_OR_RETURN(chars_map);

  auto input = filesystem::NewReadableFile(filename);

  RETURN_IF_ERROR(input->status());

  std::string line;
  chars_map->clear();
  while (input->ReadLine(&line)) {
    std::vector<std::string> fields =
        absl::StrSplit(line, '\t', absl::AllowEmpty());
    CHECK_GE(fields.size(), 1);
    if (fields.size() == 1) fields.push_back("");  // Deletion rule.
    std::vector<char32> src, trg;
    for (auto s : absl::StrSplit(fields[0], ' ')) {
      if (s.empty()) continue;
      absl::ConsumePrefix(&s, "U+");
      src.push_back(string_util::HexToInt<char32>(s));
    }
    for (auto s : absl::StrSplit(fields[1], ' ')) {
      if (s.empty()) continue;
      absl::ConsumePrefix(&s, "U+");
      trg.push_back(string_util::HexToInt<char32>(s));
    }
    CHECK_OR_RETURN(!src.empty());
    (*chars_map)[src] = trg;
  }

  return util::OkStatus();
}

// static
util::Status Builder::SaveCharsMap(absl::string_view filename,
                                   const Builder::CharsMap &chars_map) {
  auto output = filesystem::NewWritableFile(filename);
  RETURN_IF_ERROR(output->status());

  for (const auto &c : chars_map) {
    std::vector<std::string> src, trg;
    string_util::UnicodeText srcu, trgu;
    for (char32 v : c.first) {
      src.push_back(string_util::IntToHex(v));
      srcu.push_back(v);
    }
    for (char32 v : c.second) {
      trg.push_back(string_util::IntToHex(v));
      trgu.push_back(v);
    }
    std::string line = absl::StrJoin(src, " ") + "\t" +
                       absl::StrJoin(trg, " ") + "\t# " +
                       string_util::UnicodeTextToUTF8(c.first) + " => " +
                       string_util::UnicodeTextToUTF8(c.second);
    line = absl::StrReplaceAll(
        line,
        {{"\b", " "}, {"\v", " "}, {"\f", " "}, {"\n", " "}, {"\r", " "}});
    output->WriteLine(line);
  }

  return util::OkStatus();
}

// static
util::Status Builder::RemoveRedundantMap(CharsMap *chars_map) {
  CHECK_OR_RETURN(chars_map);

  CharsMap new_chars_map;
  size_t max_len = 0;
  for (const auto &p : *chars_map) {
    max_len = std::max(p.first.size(), max_len);
    if (p.first.size() == 1) {
      new_chars_map.insert(p);
    }
  }
  CHECK_GT_OR_RETURN(max_len, 0);

  // Checks whether the rules of size `len` can be reproduced by the rules of
  // size [1 .. len - 1].
  for (size_t len = 2; len <= max_len; ++len) {
    for (const auto &p : *chars_map) {
      if (p.first.size() == len &&
          p.second != Normalize(new_chars_map, p.first, len - 1)) {
        new_chars_map.insert(p);
      }
    }
  }

  // Verifies that all characters in `chars_map` are normalized by
  // `new_chars_map`.
  for (const auto &p : *chars_map) {
    CHECK_EQ_OR_RETURN(p.second, Normalize(new_chars_map, p.first, max_len));
  }

  *chars_map = std::move(new_chars_map);

  return util::OkStatus();
}
}  // namespace normalizer
}  // namespace sentencepiece
@@ -1,131 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef BUILDER_H_
#define BUILDER_H_

#include <map>
#include <string>
#include <vector>

#include "common.h"
#include "sentencepiece_model.pb.h"
#include "sentencepiece_processor.h"
#include "third_party/absl/strings/string_view.h"

namespace sentencepiece {
namespace normalizer {

// Builder creates a text normalization rule from user-defined string-to-
// string mappings. The normalization mapping is compiled into a single,
// compact blob index which is stored in the model proto.
// This class also provides pre-defined rules based on Unicode NFKC.
// https://en.wikipedia.org/wiki/Unicode_equivalence#Normalization
class Builder {
 public:
  Builder() = delete;
  ~Builder() = delete;

  // Basic Unicode character sequence.
  using Chars = std::vector<char32>;

  // String-to-string mapping.
  using CharsMap = std::map<Chars, Chars>;

  static util::Status CompileCharsMap(const CharsMap &chars_map,
                                      std::string *output);

  // Decompiles `blob` into `chars_map`.
  static util::Status DecompileCharsMap(absl::string_view blob,
                                        CharsMap *chars_map);

  // Returns a pre-compiled binary index with `name`.
  static util::Status GetPrecompiledCharsMap(absl::string_view name,
                                             std::string *output);

  // Makes a normalization mapping based on NFKC.
  //
  // Note that the Normalizer/Builder classes do not support full NFKC
  // normalization, since full NFKC normalization cannot be implemented with
  // a simple longest-matching string-to-string replacement. One unsupported
  // normalization is multiple combining marks.
  //
  // Strings with multiple combining marks cannot be normalized correctly,
  // because that requires sorting the combining marks by
  // Canonical_Combining_Class (CCC).
  // http://unicode.org/reports/tr15/#Multiple_Mark_Figure
  //
  // Example:
  //  Original:   U+1E0B U+0323
  //  Decomposed: U+0064 U+0307 U+0323
  //  NFKD:       U+0064 U+0323 U+0307 (combining characters are sorted by CCC)
  //  NFKC:       U+1E0D U+0307 (U+0064 U+0323 => U+1E0D)
  //
  // To support the normalization above with longest matching, we would need
  // to enumerate all possible permutations of combining marks in advance,
  // which is not feasible. For example, suppose there are three combining
  // marks X, Y and Z, which are sorted into one canonical order Z, Y, X by
  // NFK(D|C). In this case, all permutations (XYZ, XZY, YXZ, ...) are
  // normalized into ZYX. To implement this normalization with longest
  // matching, we would need 3! rules: XYZ=>ZYX, XZY=>ZYX, ...
  // Since Unicode has more than 100 combining characters, it is not possible
  // to expand all permutations.
  //
  // We will not implement full NFKC in SentencePiece because
  // 1) It is unusual to see decomposed Unicode characters in real text.
  // 2) Providing a flexible, user-customizable, and self-contained
  //    normalizer is the goal of SentencePiece.
  //
  // TODO(taku): Make NFC, NFD, and NFKD mappings if necessary.
  static util::Status BuildNFKCMap(CharsMap *chars_map);

  // Makes an NFKC-based mapping with NMT-specific modifications around
  // whitespaces.
  static util::Status BuildNmtNFKCMap(CharsMap *chars_map);

  // Merges the Unicode case-folding mapping into `chars_map`.
  static util::Status MergeUnicodeCaseFoldMap(CharsMap *chars_map);

  // Makes NFKC with Unicode case folding.
  static util::Status BuildNFKC_CFMap(CharsMap *chars_map);

  // Makes NMT NFKC with Unicode case folding.
  static util::Status BuildNmtNFKC_CFMap(CharsMap *chars_map);

  // Builds an NFKD normalization mapping.
  static util::Status BuildNFKDMap(CharsMap *chars_map);

  // Loads a chars map saved in `filename`.
  // Format:
  //   src_uchar1 src_uchar2 ... <tab> trg_uchar1 trg_uchar2 ...
  // (src|trg)_ucharX must be a hex representation of a Unicode code point.
  static util::Status LoadCharsMap(absl::string_view filename,
                                   CharsMap *chars_map);

  // Saves the chars map to `filename` as TSV.
  static util::Status SaveCharsMap(absl::string_view filename,
                                   const CharsMap &chars_map);

 private:
  FRIEND_TEST(BuilderTest, RemoveRedundantMapTest);

  // Removes redundant rules from `chars_map`.
  // When chars_map has "aa" => "bb" and "a" => "b", the first rule is not
  // necessary since the second rule already covers it.
  static util::Status RemoveRedundantMap(CharsMap *chars_map);
};
}  // namespace normalizer
}  // namespace sentencepiece
#endif  // BUILDER_H_
@@ -1,228 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "builder.h"
#include "common.h"
#include "filesystem.h"
#include "normalizer.h"
#include "sentencepiece_trainer.h"
#include "testharness.h"
#include "third_party/absl/strings/str_cat.h"
#include "util.h"

namespace sentencepiece {
namespace normalizer {

// Space symbol
#define WS "\xe2\x96\x81"

TEST(BuilderTest, RemoveRedundantMapTest) {
  Builder::CharsMap chars_map;

  // ab => AB, a => A, b => B, abc => CBA
  chars_map[{0x0061}] = {0x0041};
  chars_map[{0x0062}] = {0x0042};
  chars_map[{0x0061, 0x0062}] = {0x0041, 0x0042};
  chars_map[{0x0061, 0x0062, 0x0063}] = {0x0043, 0x0042, 0x0041};

  EXPECT_TRUE(Builder::RemoveRedundantMap(&chars_map).ok());
  EXPECT_EQ(3, chars_map.size());
  EXPECT_EQ(chars_map.end(), chars_map.find({0x0061, 0x0062}));
  EXPECT_NE(chars_map.end(), chars_map.find({0x0061}));
  EXPECT_NE(chars_map.end(), chars_map.find({0x0062}));
  EXPECT_NE(chars_map.end(), chars_map.find({0x0061, 0x0062, 0x0063}));
}

TEST(BuilderTest, GetPrecompiledCharsMapWithInvalidNameTest) {
  std::string output;
  EXPECT_FALSE(Builder::GetPrecompiledCharsMap("", &output).ok());
  EXPECT_FALSE(Builder::GetPrecompiledCharsMap("__UNKNOWN__", &output).ok());
}

TEST(BuilderTest, BuildNFKCMapTest) {
  Builder::CharsMap chars_map;
#ifdef ENABLE_NFKC_COMPILE
  EXPECT_TRUE(Builder::BuildNFKCMap(&chars_map).ok());
  EXPECT_TRUE(!chars_map.empty());
#else
  EXPECT_TRUE(Builder::BuildNFKCMap(&chars_map).ok());
#endif
}

TEST(BuilderTest, GetPrecompiledCharsMapTest) {
  {
    const NormalizerSpec spec =
        SentencePieceTrainer::GetNormalizerSpec("nmt_nfkc");
    const Normalizer normalizer(spec);
    EXPECT_EQ(WS "ABC", normalizer.Normalize("ABC"));
    EXPECT_EQ(WS "(株)", normalizer.Normalize("㈱"));
    EXPECT_EQ(WS "グーグル", normalizer.Normalize("グーグル"));
  }

  {
    const NormalizerSpec spec =
        SentencePieceTrainer::GetNormalizerSpec("nfkc_cf");
    const Normalizer normalizer(spec);
    EXPECT_EQ(WS "abc", normalizer.Normalize("ABC"));
    EXPECT_EQ(WS "abc", normalizer.Normalize("ABC"));
  }

  {
    const NormalizerSpec spec =
        SentencePieceTrainer::GetNormalizerSpec("nmt_nfkc_cf");
    const Normalizer normalizer(spec);
    EXPECT_EQ(WS "abc", normalizer.Normalize("ABC"));
    EXPECT_EQ(WS "abc", normalizer.Normalize("ABC"));
  }

  {
    const NormalizerSpec spec =
        SentencePieceTrainer::GetNormalizerSpec("identity");
    EXPECT_TRUE(spec.precompiled_charsmap().empty());
    const Normalizer normalizer(spec);
    EXPECT_EQ(WS "ABC", normalizer.Normalize("ABC"));
    EXPECT_EQ(WS "㈱", normalizer.Normalize("㈱"));
    EXPECT_EQ(WS "グーグル", normalizer.Normalize("グーグル"));
  }
}

TEST(BuilderTest, CompileCharsMap) {
  Builder::CharsMap chars_map;

  // Lowercase => Uppercase
  for (char32 lc = static_cast<char32>('a'); lc <= static_cast<char32>('z');
       ++lc) {
    const char32 uc = lc + 'A' - 'a';
    chars_map[{lc}] = {uc};
  }

  // あいう => abc
  chars_map[{0x3042, 0x3044, 0x3046}] = {0x0061, 0x0062, 0x0063};

  // えお => remove
  chars_map[{0x3048, 0x304A}] = {};

  NormalizerSpec spec;
  EXPECT_TRUE(
      Builder::CompileCharsMap(chars_map, spec.mutable_precompiled_charsmap())
          .ok());
  Builder::CharsMap decompiled_chars_map;
  EXPECT_TRUE(Builder::DecompileCharsMap(spec.precompiled_charsmap(),
                                         &decompiled_chars_map)
                  .ok());
  EXPECT_EQ(chars_map, decompiled_chars_map);

  spec.set_add_dummy_prefix(false);
  const Normalizer normalizer(spec);

  EXPECT_EQ("ABC", normalizer.Normalize("abc"));
  EXPECT_EQ("ABC", normalizer.Normalize("ABC"));
  EXPECT_EQ("XY" WS "Z", normalizer.Normalize("xy z"));

  EXPECT_EQ("あ", normalizer.Normalize("あ"));
  EXPECT_EQ("abc", normalizer.Normalize("あいう"));
  EXPECT_EQ("abcえ", normalizer.Normalize("あいうえ"));
  EXPECT_EQ("ABCabcD", normalizer.Normalize("abcあいうd"));
  EXPECT_EQ("abcか", normalizer.Normalize("あいうえおか"));
}

static constexpr char kTestInputData[] = "nfkc.tsv";

TEST(BuilderTest, LoadCharsMapTest) {
  Builder::CharsMap chars_map;
  ASSERT_TRUE(
      Builder::LoadCharsMap(
          util::JoinPath(absl::GetFlag(FLAGS_test_srcdir), kTestInputData),
          &chars_map)
          .ok());

  std::string precompiled, expected;
  ASSERT_TRUE(Builder::CompileCharsMap(chars_map, &precompiled).ok());

  // Round-trip.
  Builder::CharsMap decompiled_chars_map;
  ASSERT_TRUE(
      Builder::DecompileCharsMap(precompiled, &decompiled_chars_map).ok());
  EXPECT_EQ(chars_map, decompiled_chars_map);

  ASSERT_TRUE(
      Builder::SaveCharsMap(
          util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "output.tsv"),
          chars_map)
          .ok());

  Builder::CharsMap saved_chars_map;
  ASSERT_TRUE(
      Builder::LoadCharsMap(
          util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "output.tsv"),
          &saved_chars_map)
          .ok());
  EXPECT_EQ(chars_map, saved_chars_map);

#ifdef ENABLE_NFKC_COMPILE
  Builder::CharsMap nfkc_map;
  ASSERT_TRUE(Builder::BuildNFKCMap(&nfkc_map).ok());
  ASSERT_TRUE(Builder::CompileCharsMap(nfkc_map, &expected).ok());
#endif
}

TEST(BuilderTest, LoadCharsMapWithEmptyTest) {
  {
    auto output = filesystem::NewWritableFile(
        util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test.tsv"));
    output->WriteLine("0061\t0041");
    output->WriteLine("0062");
    output->WriteLine("0063\t\t#foo=>bar");
  }

  Builder::CharsMap chars_map;
  EXPECT_TRUE(Builder::LoadCharsMap(
                  util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test.tsv"),
                  &chars_map)
                  .ok());

  EXPECT_EQ(3, chars_map.size());
  EXPECT_EQ(std::vector<char32>({0x0041}), chars_map[{0x0061}]);
  EXPECT_EQ(std::vector<char32>({}), chars_map[{0x0062}]);
  EXPECT_EQ(std::vector<char32>({}), chars_map[{0x0063}]);

  EXPECT_TRUE(
      Builder::SaveCharsMap(
          util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test_out.tsv"),
          chars_map)
          .ok());

  Builder::CharsMap new_chars_map;
  EXPECT_TRUE(
      Builder::LoadCharsMap(
          util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test_out.tsv"),
          &new_chars_map)
          .ok());
  EXPECT_EQ(chars_map, new_chars_map);
}

TEST(BuilderTest, ContainsTooManySharedPrefixTest) {
  Builder::CharsMap chars_map;
  std::vector<char32> keys;
  // chars_map contains too many shared prefixes ("a", "aa", "aaa", ...).
  for (int i = 0; i < 100; ++i) {
    keys.push_back('a');
    chars_map[keys] = {'b'};
  }
  std::string output;
  EXPECT_FALSE(Builder::CompileCharsMap(chars_map, &output).ok());
}

}  // namespace normalizer
}  // namespace sentencepiece
@@ -1,923 +0,0 @@
// Generated by the protocol buffer compiler. DO NOT EDIT!
// source: sentencepiece.proto

#include "sentencepiece.pb.h"

#include <algorithm>

#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/extension_set.h>
#include <google/protobuf/wire_format_lite.h>
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
// @@protoc_insertion_point(includes)
#include <google/protobuf/port_def.inc>
extern PROTOBUF_INTERNAL_EXPORT_sentencepiece_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_SentencePieceText_sentencepiece_2eproto;
extern PROTOBUF_INTERNAL_EXPORT_sentencepiece_2eproto ::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_SentencePieceText_SentencePiece_sentencepiece_2eproto;
namespace sentencepiece {
class SentencePieceText_SentencePieceDefaultTypeInternal {
 public:
  ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed<SentencePieceText_SentencePiece> _instance;
} _SentencePieceText_SentencePiece_default_instance_;
class SentencePieceTextDefaultTypeInternal {
 public:
  ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed<SentencePieceText> _instance;
} _SentencePieceText_default_instance_;
class NBestSentencePieceTextDefaultTypeInternal {
 public:
  ::PROTOBUF_NAMESPACE_ID::internal::ExplicitlyConstructed<NBestSentencePieceText> _instance;
} _NBestSentencePieceText_default_instance_;
}  // namespace sentencepiece
static void InitDefaultsscc_info_NBestSentencePieceText_sentencepiece_2eproto() {
  GOOGLE_PROTOBUF_VERIFY_VERSION;

  {
    void* ptr = &::sentencepiece::_NBestSentencePieceText_default_instance_;
    new (ptr) ::sentencepiece::NBestSentencePieceText();
    ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr);
  }
}

::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_NBestSentencePieceText_sentencepiece_2eproto =
    {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_NBestSentencePieceText_sentencepiece_2eproto}, {
      &scc_info_SentencePieceText_sentencepiece_2eproto.base,}};

static void InitDefaultsscc_info_SentencePieceText_sentencepiece_2eproto() {
  GOOGLE_PROTOBUF_VERIFY_VERSION;

  {
    void* ptr = &::sentencepiece::_SentencePieceText_default_instance_;
    new (ptr) ::sentencepiece::SentencePieceText();
    ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr);
  }
}

::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<1> scc_info_SentencePieceText_sentencepiece_2eproto =
    {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 1, 0, InitDefaultsscc_info_SentencePieceText_sentencepiece_2eproto}, {
      &scc_info_SentencePieceText_SentencePiece_sentencepiece_2eproto.base,}};

static void InitDefaultsscc_info_SentencePieceText_SentencePiece_sentencepiece_2eproto() {
  GOOGLE_PROTOBUF_VERIFY_VERSION;

  {
    void* ptr = &::sentencepiece::_SentencePieceText_SentencePiece_default_instance_;
    new (ptr) ::sentencepiece::SentencePieceText_SentencePiece();
    ::PROTOBUF_NAMESPACE_ID::internal::OnShutdownDestroyMessage(ptr);
  }
}

::PROTOBUF_NAMESPACE_ID::internal::SCCInfo<0> scc_info_SentencePieceText_SentencePiece_sentencepiece_2eproto =
    {{ATOMIC_VAR_INIT(::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase::kUninitialized), 0, 0, InitDefaultsscc_info_SentencePieceText_SentencePiece_sentencepiece_2eproto}, {}};

namespace sentencepiece {

// ===================================================================

class SentencePieceText_SentencePiece::_Internal {
 public:
  using HasBits = decltype(std::declval<SentencePieceText_SentencePiece>()._has_bits_);
  static void set_has_piece(HasBits* has_bits) {
    (*has_bits)[0] |= 1u;
  }
  static void set_has_id(HasBits* has_bits) {
    (*has_bits)[0] |= 4u;
  }
  static void set_has_surface(HasBits* has_bits) {
    (*has_bits)[0] |= 2u;
  }
  static void set_has_begin(HasBits* has_bits) {
    (*has_bits)[0] |= 8u;
  }
  static void set_has_end(HasBits* has_bits) {
    (*has_bits)[0] |= 16u;
  }
};

SentencePieceText_SentencePiece::SentencePieceText_SentencePiece(::PROTOBUF_NAMESPACE_ID::Arena* arena)
  : ::PROTOBUF_NAMESPACE_ID::MessageLite(arena),
    _extensions_(arena) {
  SharedCtor();
  RegisterArenaDtor(arena);
  // @@protoc_insertion_point(arena_constructor:sentencepiece.SentencePieceText.SentencePiece)
}
SentencePieceText_SentencePiece::SentencePieceText_SentencePiece(const SentencePieceText_SentencePiece& from)
  : ::PROTOBUF_NAMESPACE_ID::MessageLite(),
    _has_bits_(from._has_bits_) {
  _internal_metadata_.MergeFrom<std::string>(from._internal_metadata_);
  _extensions_.MergeFrom(from._extensions_);
  piece_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
  if (from._internal_has_piece()) {
    piece_.Set(::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr::EmptyDefault{}, from._internal_piece(),
      GetArena());
  }
  surface_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
  if (from._internal_has_surface()) {
    surface_.Set(::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr::EmptyDefault{}, from._internal_surface(),
      GetArena());
  }
  ::memcpy(&id_, &from.id_,
    static_cast<size_t>(reinterpret_cast<char*>(&end_) -
    reinterpret_cast<char*>(&id_)) + sizeof(end_));
  // @@protoc_insertion_point(copy_constructor:sentencepiece.SentencePieceText.SentencePiece)
}

void SentencePieceText_SentencePiece::SharedCtor() {
  ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_SentencePieceText_SentencePiece_sentencepiece_2eproto.base);
  piece_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
  surface_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
  ::memset(reinterpret_cast<char*>(this) + static_cast<size_t>(
      reinterpret_cast<char*>(&id_) - reinterpret_cast<char*>(this)),
      0, static_cast<size_t>(reinterpret_cast<char*>(&end_) -
      reinterpret_cast<char*>(&id_)) + sizeof(end_));
}

SentencePieceText_SentencePiece::~SentencePieceText_SentencePiece() {
  // @@protoc_insertion_point(destructor:sentencepiece.SentencePieceText.SentencePiece)
  SharedDtor();
  _internal_metadata_.Delete<std::string>();
}

void SentencePieceText_SentencePiece::SharedDtor() {
  GOOGLE_DCHECK(GetArena() == nullptr);
  piece_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
  surface_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
}

void SentencePieceText_SentencePiece::ArenaDtor(void* object) {
  SentencePieceText_SentencePiece* _this = reinterpret_cast< SentencePieceText_SentencePiece* >(object);
  (void)_this;
}
void SentencePieceText_SentencePiece::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) {
}
void SentencePieceText_SentencePiece::SetCachedSize(int size) const {
  _cached_size_.Set(size);
}
const SentencePieceText_SentencePiece& SentencePieceText_SentencePiece::default_instance() {
  ::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_SentencePieceText_SentencePiece_sentencepiece_2eproto.base);
  return *internal_default_instance();
}


void SentencePieceText_SentencePiece::Clear() {
// @@protoc_insertion_point(message_clear_start:sentencepiece.SentencePieceText.SentencePiece)
  ::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
  // Prevent compiler warnings about cached_has_bits being unused
  (void) cached_has_bits;

  _extensions_.Clear();
  cached_has_bits = _has_bits_[0];
  if (cached_has_bits & 0x00000003u) {
    if (cached_has_bits & 0x00000001u) {
      piece_.ClearNonDefaultToEmpty();
    }
    if (cached_has_bits & 0x00000002u) {
      surface_.ClearNonDefaultToEmpty();
    }
  }
  if (cached_has_bits & 0x0000001cu) {
    ::memset(&id_, 0, static_cast<size_t>(
        reinterpret_cast<char*>(&end_) -
        reinterpret_cast<char*>(&id_)) + sizeof(end_));
  }
  _has_bits_.Clear();
  _internal_metadata_.Clear<std::string>();
}

const char* SentencePieceText_SentencePiece::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) {
#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure
  _Internal::HasBits has_bits{};
  while (!ctx->Done(&ptr)) {
    ::PROTOBUF_NAMESPACE_ID::uint32 tag;
    ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag);
    CHK_(ptr);
    switch (tag >> 3) {
      // optional string piece = 1;
      case 1:
        if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) {
          auto str = _internal_mutable_piece();
          ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx);
          CHK_(ptr);
        } else goto handle_unusual;
        continue;
      // optional uint32 id = 2;
      case 2:
        if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 16)) {
          _Internal::set_has_id(&has_bits);
          id_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr);
          CHK_(ptr);
        } else goto handle_unusual;
        continue;
      // optional string surface = 3;
      case 3:
        if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 26)) {
          auto str = _internal_mutable_surface();
          ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx);
          CHK_(ptr);
        } else goto handle_unusual;
        continue;
      // optional uint32 begin = 4;
      case 4:
        if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 32)) {
          _Internal::set_has_begin(&has_bits);
          begin_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr);
          CHK_(ptr);
        } else goto handle_unusual;
        continue;
      // optional uint32 end = 5;
      case 5:
        if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 40)) {
          _Internal::set_has_end(&has_bits);
          end_ = ::PROTOBUF_NAMESPACE_ID::internal::ReadVarint32(&ptr);
          CHK_(ptr);
        } else goto handle_unusual;
        continue;
      default: {
      handle_unusual:
        if ((tag & 7) == 4 || tag == 0) {
          ctx->SetLastTag(tag);
          goto success;
        }
        if ((1600u <= tag)) {
          ptr = _extensions_.ParseField(tag, ptr,
              internal_default_instance(), &_internal_metadata_, ctx);
          CHK_(ptr != nullptr);
          continue;
        }
        ptr = UnknownFieldParse(tag,
            _internal_metadata_.mutable_unknown_fields<std::string>(),
            ptr, ctx);
        CHK_(ptr != nullptr);
        continue;
      }
    }  // switch
  }  // while
success:
  _has_bits_.Or(has_bits);
  return ptr;
failure:
  ptr = nullptr;
  goto success;
#undef CHK_
|
||||
}
|
||||
|
||||
::PROTOBUF_NAMESPACE_ID::uint8* SentencePieceText_SentencePiece::_InternalSerialize(
|
||||
::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const {
|
||||
// @@protoc_insertion_point(serialize_to_array_start:sentencepiece.SentencePieceText.SentencePiece)
|
||||
::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
|
||||
(void) cached_has_bits;
|
||||
|
||||
cached_has_bits = _has_bits_[0];
|
||||
// optional string piece = 1;
|
||||
if (cached_has_bits & 0x00000001u) {
|
||||
target = stream->WriteStringMaybeAliased(
|
||||
1, this->_internal_piece(), target);
|
||||
}
|
||||
|
||||
// optional uint32 id = 2;
|
||||
if (cached_has_bits & 0x00000004u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteUInt32ToArray(2, this->_internal_id(), target);
|
||||
}
|
||||
|
||||
// optional string surface = 3;
|
||||
if (cached_has_bits & 0x00000002u) {
|
||||
target = stream->WriteStringMaybeAliased(
|
||||
3, this->_internal_surface(), target);
|
||||
}
|
||||
|
||||
// optional uint32 begin = 4;
|
||||
if (cached_has_bits & 0x00000008u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteUInt32ToArray(4, this->_internal_begin(), target);
|
||||
}
|
||||
|
||||
// optional uint32 end = 5;
|
||||
if (cached_has_bits & 0x00000010u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteUInt32ToArray(5, this->_internal_end(), target);
|
||||
}
|
||||
|
||||
// Extension range [200, 536870912)
|
||||
target = _extensions_._InternalSerialize(
|
||||
200, 536870912, target, stream);
|
||||
|
||||
if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) {
|
||||
target = stream->WriteRaw(_internal_metadata_.unknown_fields<std::string>(::PROTOBUF_NAMESPACE_ID::internal::GetEmptyString).data(),
|
||||
static_cast<int>(_internal_metadata_.unknown_fields<std::string>(::PROTOBUF_NAMESPACE_ID::internal::GetEmptyString).size()), target);
|
||||
}
|
||||
// @@protoc_insertion_point(serialize_to_array_end:sentencepiece.SentencePieceText.SentencePiece)
|
||||
return target;
|
||||
}
|
||||
|
||||
size_t SentencePieceText_SentencePiece::ByteSizeLong() const {
|
||||
// @@protoc_insertion_point(message_byte_size_start:sentencepiece.SentencePieceText.SentencePiece)
|
||||
size_t total_size = 0;
|
||||
|
||||
total_size += _extensions_.ByteSize();
|
||||
|
||||
::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
|
||||
// Prevent compiler warnings about cached_has_bits being unused
|
||||
(void) cached_has_bits;
|
||||
|
||||
cached_has_bits = _has_bits_[0];
|
||||
if (cached_has_bits & 0x0000001fu) {
|
||||
// optional string piece = 1;
|
||||
if (cached_has_bits & 0x00000001u) {
|
||||
total_size += 1 +
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize(
|
||||
this->_internal_piece());
|
||||
}
|
||||
|
||||
// optional string surface = 3;
|
||||
if (cached_has_bits & 0x00000002u) {
|
||||
total_size += 1 +
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize(
|
||||
this->_internal_surface());
|
||||
}
|
||||
|
||||
// optional uint32 id = 2;
|
||||
if (cached_has_bits & 0x00000004u) {
|
||||
total_size += 1 +
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::UInt32Size(
|
||||
this->_internal_id());
|
||||
}
|
||||
|
||||
// optional uint32 begin = 4;
|
||||
if (cached_has_bits & 0x00000008u) {
|
||||
total_size += 1 +
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::UInt32Size(
|
||||
this->_internal_begin());
|
||||
}
|
||||
|
||||
// optional uint32 end = 5;
|
||||
if (cached_has_bits & 0x00000010u) {
|
||||
total_size += 1 +
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::UInt32Size(
|
||||
this->_internal_end());
|
||||
}
|
||||
|
||||
}
|
||||
if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) {
|
||||
total_size += _internal_metadata_.unknown_fields<std::string>(::PROTOBUF_NAMESPACE_ID::internal::GetEmptyString).size();
|
||||
}
|
||||
int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size);
|
||||
SetCachedSize(cached_size);
|
||||
return total_size;
|
||||
}
|
||||
|
||||
void SentencePieceText_SentencePiece::CheckTypeAndMergeFrom(
|
||||
const ::PROTOBUF_NAMESPACE_ID::MessageLite& from) {
|
||||
MergeFrom(*::PROTOBUF_NAMESPACE_ID::internal::DownCast<const SentencePieceText_SentencePiece*>(
|
||||
&from));
|
||||
}
|
||||
|
||||
void SentencePieceText_SentencePiece::MergeFrom(const SentencePieceText_SentencePiece& from) {
|
||||
// @@protoc_insertion_point(class_specific_merge_from_start:sentencepiece.SentencePieceText.SentencePiece)
|
||||
GOOGLE_DCHECK_NE(&from, this);
|
||||
_extensions_.MergeFrom(from._extensions_);
|
||||
_internal_metadata_.MergeFrom<std::string>(from._internal_metadata_);
|
||||
::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
|
||||
(void) cached_has_bits;
|
||||
|
||||
cached_has_bits = from._has_bits_[0];
|
||||
if (cached_has_bits & 0x0000001fu) {
|
||||
if (cached_has_bits & 0x00000001u) {
|
||||
_internal_set_piece(from._internal_piece());
|
||||
}
|
||||
if (cached_has_bits & 0x00000002u) {
|
||||
_internal_set_surface(from._internal_surface());
|
||||
}
|
||||
if (cached_has_bits & 0x00000004u) {
|
||||
id_ = from.id_;
|
||||
}
|
||||
if (cached_has_bits & 0x00000008u) {
|
||||
begin_ = from.begin_;
|
||||
}
|
||||
if (cached_has_bits & 0x00000010u) {
|
||||
end_ = from.end_;
|
||||
}
|
||||
_has_bits_[0] |= cached_has_bits;
|
||||
}
|
||||
}
|
||||
|
||||
void SentencePieceText_SentencePiece::CopyFrom(const SentencePieceText_SentencePiece& from) {
|
||||
// @@protoc_insertion_point(class_specific_copy_from_start:sentencepiece.SentencePieceText.SentencePiece)
|
||||
if (&from == this) return;
|
||||
Clear();
|
||||
MergeFrom(from);
|
||||
}
|
||||
|
||||
bool SentencePieceText_SentencePiece::IsInitialized() const {
|
||||
if (!_extensions_.IsInitialized()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
void SentencePieceText_SentencePiece::InternalSwap(SentencePieceText_SentencePiece* other) {
|
||||
using std::swap;
|
||||
_extensions_.Swap(&other->_extensions_);
|
||||
_internal_metadata_.Swap<std::string>(&other->_internal_metadata_);
|
||||
swap(_has_bits_[0], other->_has_bits_[0]);
|
||||
piece_.Swap(&other->piece_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena());
|
||||
surface_.Swap(&other->surface_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena());
|
||||
::PROTOBUF_NAMESPACE_ID::internal::memswap<
|
||||
PROTOBUF_FIELD_OFFSET(SentencePieceText_SentencePiece, end_)
|
||||
+ sizeof(SentencePieceText_SentencePiece::end_)
|
||||
- PROTOBUF_FIELD_OFFSET(SentencePieceText_SentencePiece, id_)>(
|
||||
reinterpret_cast<char*>(&id_),
|
||||
reinterpret_cast<char*>(&other->id_));
|
||||
}
|
||||
|
||||
std::string SentencePieceText_SentencePiece::GetTypeName() const {
|
||||
return "sentencepiece.SentencePieceText.SentencePiece";
|
||||
}
|
||||
|
||||
|
||||
// ===================================================================
|
||||
|
||||
class SentencePieceText::_Internal {
|
||||
public:
|
||||
using HasBits = decltype(std::declval<SentencePieceText>()._has_bits_);
|
||||
static void set_has_text(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 1u;
|
||||
}
|
||||
static void set_has_score(HasBits* has_bits) {
|
||||
(*has_bits)[0] |= 2u;
|
||||
}
|
||||
};
|
||||
|
||||
SentencePieceText::SentencePieceText(::PROTOBUF_NAMESPACE_ID::Arena* arena)
|
||||
: ::PROTOBUF_NAMESPACE_ID::MessageLite(arena),
|
||||
_extensions_(arena),
|
||||
pieces_(arena) {
|
||||
SharedCtor();
|
||||
RegisterArenaDtor(arena);
|
||||
// @@protoc_insertion_point(arena_constructor:sentencepiece.SentencePieceText)
|
||||
}
|
||||
SentencePieceText::SentencePieceText(const SentencePieceText& from)
|
||||
: ::PROTOBUF_NAMESPACE_ID::MessageLite(),
|
||||
_has_bits_(from._has_bits_),
|
||||
pieces_(from.pieces_) {
|
||||
_internal_metadata_.MergeFrom<std::string>(from._internal_metadata_);
|
||||
_extensions_.MergeFrom(from._extensions_);
|
||||
text_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
|
||||
if (from._internal_has_text()) {
|
||||
text_.Set(::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr::EmptyDefault{}, from._internal_text(),
|
||||
GetArena());
|
||||
}
|
||||
score_ = from.score_;
|
||||
// @@protoc_insertion_point(copy_constructor:sentencepiece.SentencePieceText)
|
||||
}
|
||||
|
||||
void SentencePieceText::SharedCtor() {
|
||||
::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_SentencePieceText_sentencepiece_2eproto.base);
|
||||
text_.UnsafeSetDefault(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
|
||||
score_ = 0;
|
||||
}
|
||||
|
||||
SentencePieceText::~SentencePieceText() {
|
||||
// @@protoc_insertion_point(destructor:sentencepiece.SentencePieceText)
|
||||
SharedDtor();
|
||||
_internal_metadata_.Delete<std::string>();
|
||||
}
|
||||
|
||||
void SentencePieceText::SharedDtor() {
|
||||
GOOGLE_DCHECK(GetArena() == nullptr);
|
||||
text_.DestroyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
|
||||
}
|
||||
|
||||
void SentencePieceText::ArenaDtor(void* object) {
|
||||
SentencePieceText* _this = reinterpret_cast< SentencePieceText* >(object);
|
||||
(void)_this;
|
||||
}
|
||||
void SentencePieceText::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) {
|
||||
}
|
||||
void SentencePieceText::SetCachedSize(int size) const {
|
||||
_cached_size_.Set(size);
|
||||
}
|
||||
const SentencePieceText& SentencePieceText::default_instance() {
|
||||
::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_SentencePieceText_sentencepiece_2eproto.base);
|
||||
return *internal_default_instance();
|
||||
}
|
||||
|
||||
|
||||
void SentencePieceText::Clear() {
|
||||
// @@protoc_insertion_point(message_clear_start:sentencepiece.SentencePieceText)
|
||||
::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
|
||||
// Prevent compiler warnings about cached_has_bits being unused
|
||||
(void) cached_has_bits;
|
||||
|
||||
_extensions_.Clear();
|
||||
pieces_.Clear();
|
||||
cached_has_bits = _has_bits_[0];
|
||||
if (cached_has_bits & 0x00000001u) {
|
||||
text_.ClearNonDefaultToEmpty();
|
||||
}
|
||||
score_ = 0;
|
||||
_has_bits_.Clear();
|
||||
_internal_metadata_.Clear<std::string>();
|
||||
}
|
||||
|
||||
const char* SentencePieceText::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) {
|
||||
#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure
|
||||
_Internal::HasBits has_bits{};
|
||||
while (!ctx->Done(&ptr)) {
|
||||
::PROTOBUF_NAMESPACE_ID::uint32 tag;
|
||||
ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag);
|
||||
CHK_(ptr);
|
||||
switch (tag >> 3) {
|
||||
// optional string text = 1;
|
||||
case 1:
|
||||
if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) {
|
||||
auto str = _internal_mutable_text();
|
||||
ptr = ::PROTOBUF_NAMESPACE_ID::internal::InlineGreedyStringParser(str, ptr, ctx);
|
||||
CHK_(ptr);
|
||||
} else goto handle_unusual;
|
||||
continue;
|
||||
// repeated .sentencepiece.SentencePieceText.SentencePiece pieces = 2;
|
||||
case 2:
|
||||
if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 18)) {
|
||||
ptr -= 1;
|
||||
do {
|
||||
ptr += 1;
|
||||
ptr = ctx->ParseMessage(_internal_add_pieces(), ptr);
|
||||
CHK_(ptr);
|
||||
if (!ctx->DataAvailable(ptr)) break;
|
||||
} while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<18>(ptr));
|
||||
} else goto handle_unusual;
|
||||
continue;
|
||||
// optional float score = 3;
|
||||
case 3:
|
||||
if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 29)) {
|
||||
_Internal::set_has_score(&has_bits);
|
||||
score_ = ::PROTOBUF_NAMESPACE_ID::internal::UnalignedLoad<float>(ptr);
|
||||
ptr += sizeof(float);
|
||||
} else goto handle_unusual;
|
||||
continue;
|
||||
default: {
|
||||
handle_unusual:
|
||||
if ((tag & 7) == 4 || tag == 0) {
|
||||
ctx->SetLastTag(tag);
|
||||
goto success;
|
||||
}
|
||||
if ((1600u <= tag)) {
|
||||
ptr = _extensions_.ParseField(tag, ptr,
|
||||
internal_default_instance(), &_internal_metadata_, ctx);
|
||||
CHK_(ptr != nullptr);
|
||||
continue;
|
||||
}
|
||||
ptr = UnknownFieldParse(tag,
|
||||
_internal_metadata_.mutable_unknown_fields<std::string>(),
|
||||
ptr, ctx);
|
||||
CHK_(ptr != nullptr);
|
||||
continue;
|
||||
}
|
||||
} // switch
|
||||
} // while
|
||||
success:
|
||||
_has_bits_.Or(has_bits);
|
||||
return ptr;
|
||||
failure:
|
||||
ptr = nullptr;
|
||||
goto success;
|
||||
#undef CHK_
|
||||
}
|
||||
|
||||
::PROTOBUF_NAMESPACE_ID::uint8* SentencePieceText::_InternalSerialize(
|
||||
::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const {
|
||||
// @@protoc_insertion_point(serialize_to_array_start:sentencepiece.SentencePieceText)
|
||||
::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
|
||||
(void) cached_has_bits;
|
||||
|
||||
cached_has_bits = _has_bits_[0];
|
||||
// optional string text = 1;
|
||||
if (cached_has_bits & 0x00000001u) {
|
||||
target = stream->WriteStringMaybeAliased(
|
||||
1, this->_internal_text(), target);
|
||||
}
|
||||
|
||||
// repeated .sentencepiece.SentencePieceText.SentencePiece pieces = 2;
|
||||
for (unsigned int i = 0,
|
||||
n = static_cast<unsigned int>(this->_internal_pieces_size()); i < n; i++) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::
|
||||
InternalWriteMessage(2, this->_internal_pieces(i), target, stream);
|
||||
}
|
||||
|
||||
// optional float score = 3;
|
||||
if (cached_has_bits & 0x00000002u) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::WriteFloatToArray(3, this->_internal_score(), target);
|
||||
}
|
||||
|
||||
// Extension range [200, 536870912)
|
||||
target = _extensions_._InternalSerialize(
|
||||
200, 536870912, target, stream);
|
||||
|
||||
if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) {
|
||||
target = stream->WriteRaw(_internal_metadata_.unknown_fields<std::string>(::PROTOBUF_NAMESPACE_ID::internal::GetEmptyString).data(),
|
||||
static_cast<int>(_internal_metadata_.unknown_fields<std::string>(::PROTOBUF_NAMESPACE_ID::internal::GetEmptyString).size()), target);
|
||||
}
|
||||
// @@protoc_insertion_point(serialize_to_array_end:sentencepiece.SentencePieceText)
|
||||
return target;
|
||||
}
|
||||
|
||||
size_t SentencePieceText::ByteSizeLong() const {
|
||||
// @@protoc_insertion_point(message_byte_size_start:sentencepiece.SentencePieceText)
|
||||
size_t total_size = 0;
|
||||
|
||||
total_size += _extensions_.ByteSize();
|
||||
|
||||
::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
|
||||
// Prevent compiler warnings about cached_has_bits being unused
|
||||
(void) cached_has_bits;
|
||||
|
||||
// repeated .sentencepiece.SentencePieceText.SentencePiece pieces = 2;
|
||||
total_size += 1UL * this->_internal_pieces_size();
|
||||
for (const auto& msg : this->pieces_) {
|
||||
total_size +=
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg);
|
||||
}
|
||||
|
||||
cached_has_bits = _has_bits_[0];
|
||||
if (cached_has_bits & 0x00000003u) {
|
||||
// optional string text = 1;
|
||||
if (cached_has_bits & 0x00000001u) {
|
||||
total_size += 1 +
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::StringSize(
|
||||
this->_internal_text());
|
||||
}
|
||||
|
||||
// optional float score = 3;
|
||||
if (cached_has_bits & 0x00000002u) {
|
||||
total_size += 1 + 4;
|
||||
}
|
||||
|
||||
}
|
||||
if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) {
|
||||
total_size += _internal_metadata_.unknown_fields<std::string>(::PROTOBUF_NAMESPACE_ID::internal::GetEmptyString).size();
|
||||
}
|
||||
int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size);
|
||||
SetCachedSize(cached_size);
|
||||
return total_size;
|
||||
}
|
||||
|
||||
void SentencePieceText::CheckTypeAndMergeFrom(
|
||||
const ::PROTOBUF_NAMESPACE_ID::MessageLite& from) {
|
||||
MergeFrom(*::PROTOBUF_NAMESPACE_ID::internal::DownCast<const SentencePieceText*>(
|
||||
&from));
|
||||
}
|
||||
|
||||
void SentencePieceText::MergeFrom(const SentencePieceText& from) {
|
||||
// @@protoc_insertion_point(class_specific_merge_from_start:sentencepiece.SentencePieceText)
|
||||
GOOGLE_DCHECK_NE(&from, this);
|
||||
_extensions_.MergeFrom(from._extensions_);
|
||||
_internal_metadata_.MergeFrom<std::string>(from._internal_metadata_);
|
||||
::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
|
||||
(void) cached_has_bits;
|
||||
|
||||
pieces_.MergeFrom(from.pieces_);
|
||||
cached_has_bits = from._has_bits_[0];
|
||||
if (cached_has_bits & 0x00000003u) {
|
||||
if (cached_has_bits & 0x00000001u) {
|
||||
_internal_set_text(from._internal_text());
|
||||
}
|
||||
if (cached_has_bits & 0x00000002u) {
|
||||
score_ = from.score_;
|
||||
}
|
||||
_has_bits_[0] |= cached_has_bits;
|
||||
}
|
||||
}
|
||||
|
||||
void SentencePieceText::CopyFrom(const SentencePieceText& from) {
|
||||
// @@protoc_insertion_point(class_specific_copy_from_start:sentencepiece.SentencePieceText)
|
||||
if (&from == this) return;
|
||||
Clear();
|
||||
MergeFrom(from);
|
||||
}
|
||||
|
||||
bool SentencePieceText::IsInitialized() const {
|
||||
if (!_extensions_.IsInitialized()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!::PROTOBUF_NAMESPACE_ID::internal::AllAreInitialized(pieces_)) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
void SentencePieceText::InternalSwap(SentencePieceText* other) {
|
||||
using std::swap;
|
||||
_extensions_.Swap(&other->_extensions_);
|
||||
_internal_metadata_.Swap<std::string>(&other->_internal_metadata_);
|
||||
swap(_has_bits_[0], other->_has_bits_[0]);
|
||||
pieces_.InternalSwap(&other->pieces_);
|
||||
text_.Swap(&other->text_, &::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), GetArena());
|
||||
swap(score_, other->score_);
|
||||
}
|
||||
|
||||
std::string SentencePieceText::GetTypeName() const {
|
||||
return "sentencepiece.SentencePieceText";
|
||||
}
|
||||
|
||||
|
||||
// ===================================================================
|
||||
|
||||
class NBestSentencePieceText::_Internal {
|
||||
public:
|
||||
};
|
||||
|
||||
NBestSentencePieceText::NBestSentencePieceText(::PROTOBUF_NAMESPACE_ID::Arena* arena)
|
||||
: ::PROTOBUF_NAMESPACE_ID::MessageLite(arena),
|
||||
nbests_(arena) {
|
||||
SharedCtor();
|
||||
RegisterArenaDtor(arena);
|
||||
// @@protoc_insertion_point(arena_constructor:sentencepiece.NBestSentencePieceText)
|
||||
}
|
||||
NBestSentencePieceText::NBestSentencePieceText(const NBestSentencePieceText& from)
|
||||
: ::PROTOBUF_NAMESPACE_ID::MessageLite(),
|
||||
nbests_(from.nbests_) {
|
||||
_internal_metadata_.MergeFrom<std::string>(from._internal_metadata_);
|
||||
// @@protoc_insertion_point(copy_constructor:sentencepiece.NBestSentencePieceText)
|
||||
}
|
||||
|
||||
void NBestSentencePieceText::SharedCtor() {
|
||||
::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&scc_info_NBestSentencePieceText_sentencepiece_2eproto.base);
|
||||
}
|
||||
|
||||
NBestSentencePieceText::~NBestSentencePieceText() {
|
||||
// @@protoc_insertion_point(destructor:sentencepiece.NBestSentencePieceText)
|
||||
SharedDtor();
|
||||
_internal_metadata_.Delete<std::string>();
|
||||
}
|
||||
|
||||
void NBestSentencePieceText::SharedDtor() {
|
||||
GOOGLE_DCHECK(GetArena() == nullptr);
|
||||
}
|
||||
|
||||
void NBestSentencePieceText::ArenaDtor(void* object) {
|
||||
NBestSentencePieceText* _this = reinterpret_cast< NBestSentencePieceText* >(object);
|
||||
(void)_this;
|
||||
}
|
||||
void NBestSentencePieceText::RegisterArenaDtor(::PROTOBUF_NAMESPACE_ID::Arena*) {
|
||||
}
|
||||
void NBestSentencePieceText::SetCachedSize(int size) const {
|
||||
_cached_size_.Set(size);
|
||||
}
|
||||
const NBestSentencePieceText& NBestSentencePieceText::default_instance() {
|
||||
::PROTOBUF_NAMESPACE_ID::internal::InitSCC(&::scc_info_NBestSentencePieceText_sentencepiece_2eproto.base);
|
||||
return *internal_default_instance();
|
||||
}
|
||||
|
||||
|
||||
void NBestSentencePieceText::Clear() {
|
||||
// @@protoc_insertion_point(message_clear_start:sentencepiece.NBestSentencePieceText)
|
||||
::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
|
||||
// Prevent compiler warnings about cached_has_bits being unused
|
||||
(void) cached_has_bits;
|
||||
|
||||
nbests_.Clear();
|
||||
_internal_metadata_.Clear<std::string>();
|
||||
}
|
||||
|
||||
const char* NBestSentencePieceText::_InternalParse(const char* ptr, ::PROTOBUF_NAMESPACE_ID::internal::ParseContext* ctx) {
|
||||
#define CHK_(x) if (PROTOBUF_PREDICT_FALSE(!(x))) goto failure
|
||||
while (!ctx->Done(&ptr)) {
|
||||
::PROTOBUF_NAMESPACE_ID::uint32 tag;
|
||||
ptr = ::PROTOBUF_NAMESPACE_ID::internal::ReadTag(ptr, &tag);
|
||||
CHK_(ptr);
|
||||
switch (tag >> 3) {
|
||||
// repeated .sentencepiece.SentencePieceText nbests = 1;
|
||||
case 1:
|
||||
if (PROTOBUF_PREDICT_TRUE(static_cast<::PROTOBUF_NAMESPACE_ID::uint8>(tag) == 10)) {
|
||||
ptr -= 1;
|
||||
do {
|
||||
ptr += 1;
|
||||
ptr = ctx->ParseMessage(_internal_add_nbests(), ptr);
|
||||
CHK_(ptr);
|
||||
if (!ctx->DataAvailable(ptr)) break;
|
||||
} while (::PROTOBUF_NAMESPACE_ID::internal::ExpectTag<10>(ptr));
|
||||
} else goto handle_unusual;
|
||||
continue;
|
||||
default: {
|
||||
handle_unusual:
|
||||
if ((tag & 7) == 4 || tag == 0) {
|
||||
ctx->SetLastTag(tag);
|
||||
goto success;
|
||||
}
|
||||
ptr = UnknownFieldParse(tag,
|
||||
_internal_metadata_.mutable_unknown_fields<std::string>(),
|
||||
ptr, ctx);
|
||||
CHK_(ptr != nullptr);
|
||||
continue;
|
||||
}
|
||||
} // switch
|
||||
} // while
|
||||
success:
|
||||
return ptr;
|
||||
failure:
|
||||
ptr = nullptr;
|
||||
goto success;
|
||||
#undef CHK_
|
||||
}
|
||||
|
||||
::PROTOBUF_NAMESPACE_ID::uint8* NBestSentencePieceText::_InternalSerialize(
|
||||
::PROTOBUF_NAMESPACE_ID::uint8* target, ::PROTOBUF_NAMESPACE_ID::io::EpsCopyOutputStream* stream) const {
|
||||
// @@protoc_insertion_point(serialize_to_array_start:sentencepiece.NBestSentencePieceText)
|
||||
::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
|
||||
(void) cached_has_bits;
|
||||
|
||||
// repeated .sentencepiece.SentencePieceText nbests = 1;
|
||||
for (unsigned int i = 0,
|
||||
n = static_cast<unsigned int>(this->_internal_nbests_size()); i < n; i++) {
|
||||
target = stream->EnsureSpace(target);
|
||||
target = ::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::
|
||||
InternalWriteMessage(1, this->_internal_nbests(i), target, stream);
|
||||
}
|
||||
|
||||
if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) {
|
||||
target = stream->WriteRaw(_internal_metadata_.unknown_fields<std::string>(::PROTOBUF_NAMESPACE_ID::internal::GetEmptyString).data(),
|
||||
static_cast<int>(_internal_metadata_.unknown_fields<std::string>(::PROTOBUF_NAMESPACE_ID::internal::GetEmptyString).size()), target);
|
||||
}
|
||||
// @@protoc_insertion_point(serialize_to_array_end:sentencepiece.NBestSentencePieceText)
|
||||
return target;
|
||||
}
|
||||
|
||||
size_t NBestSentencePieceText::ByteSizeLong() const {
|
||||
// @@protoc_insertion_point(message_byte_size_start:sentencepiece.NBestSentencePieceText)
|
||||
size_t total_size = 0;
|
||||
|
||||
::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
|
||||
// Prevent compiler warnings about cached_has_bits being unused
|
||||
(void) cached_has_bits;
|
||||
|
||||
// repeated .sentencepiece.SentencePieceText nbests = 1;
|
||||
total_size += 1UL * this->_internal_nbests_size();
|
||||
for (const auto& msg : this->nbests_) {
|
||||
total_size +=
|
||||
::PROTOBUF_NAMESPACE_ID::internal::WireFormatLite::MessageSize(msg);
|
||||
}
|
||||
|
||||
if (PROTOBUF_PREDICT_FALSE(_internal_metadata_.have_unknown_fields())) {
|
||||
total_size += _internal_metadata_.unknown_fields<std::string>(::PROTOBUF_NAMESPACE_ID::internal::GetEmptyString).size();
|
||||
}
|
||||
int cached_size = ::PROTOBUF_NAMESPACE_ID::internal::ToCachedSize(total_size);
|
||||
SetCachedSize(cached_size);
|
||||
return total_size;
|
||||
}
|
||||
|
||||
void NBestSentencePieceText::CheckTypeAndMergeFrom(
|
||||
const ::PROTOBUF_NAMESPACE_ID::MessageLite& from) {
|
||||
MergeFrom(*::PROTOBUF_NAMESPACE_ID::internal::DownCast<const NBestSentencePieceText*>(
|
||||
&from));
|
||||
}
|
||||
|
||||
void NBestSentencePieceText::MergeFrom(const NBestSentencePieceText& from) {
|
||||
// @@protoc_insertion_point(class_specific_merge_from_start:sentencepiece.NBestSentencePieceText)
|
||||
GOOGLE_DCHECK_NE(&from, this);
|
||||
_internal_metadata_.MergeFrom<std::string>(from._internal_metadata_);
|
||||
::PROTOBUF_NAMESPACE_ID::uint32 cached_has_bits = 0;
|
||||
(void) cached_has_bits;
|
||||
|
||||
nbests_.MergeFrom(from.nbests_);
|
||||
}
|
||||
|
||||
void NBestSentencePieceText::CopyFrom(const NBestSentencePieceText& from) {
|
||||
// @@protoc_insertion_point(class_specific_copy_from_start:sentencepiece.NBestSentencePieceText)
|
||||
if (&from == this) return;
|
||||
Clear();
|
||||
MergeFrom(from);
|
||||
}
|
||||
|
||||
bool NBestSentencePieceText::IsInitialized() const {
|
||||
if (!::PROTOBUF_NAMESPACE_ID::internal::AllAreInitialized(nbests_)) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
void NBestSentencePieceText::InternalSwap(NBestSentencePieceText* other) {
|
||||
using std::swap;
|
||||
_internal_metadata_.Swap<std::string>(&other->_internal_metadata_);
|
||||
nbests_.InternalSwap(&other->nbests_);
|
||||
}
|
||||
|
||||
std::string NBestSentencePieceText::GetTypeName() const {
|
||||
return "sentencepiece.NBestSentencePieceText";
|
||||
}
|
||||
|
||||
|
||||
// @@protoc_insertion_point(namespace_scope)
|
||||
} // namespace sentencepiece
|
||||
PROTOBUF_NAMESPACE_OPEN
|
||||
template<> PROTOBUF_NOINLINE ::sentencepiece::SentencePieceText_SentencePiece* Arena::CreateMaybeMessage< ::sentencepiece::SentencePieceText_SentencePiece >(Arena* arena) {
|
||||
return Arena::CreateMessageInternal< ::sentencepiece::SentencePieceText_SentencePiece >(arena);
|
||||
}
|
||||
template<> PROTOBUF_NOINLINE ::sentencepiece::SentencePieceText* Arena::CreateMaybeMessage< ::sentencepiece::SentencePieceText >(Arena* arena) {
|
||||
return Arena::CreateMessageInternal< ::sentencepiece::SentencePieceText >(arena);
|
||||
}
|
||||
template<> PROTOBUF_NOINLINE ::sentencepiece::NBestSentencePieceText* Arena::CreateMaybeMessage< ::sentencepiece::NBestSentencePieceText >(Arena* arena) {
|
||||
return Arena::CreateMessageInternal< ::sentencepiece::NBestSentencePieceText >(arena);
|
||||
}
|
||||
PROTOBUF_NAMESPACE_CLOSE
|
||||
|
||||
// @@protoc_insertion_point(global_scope)
|
||||
#include <google/protobuf/port_undef.inc>
|
||||
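Editor's note: the generated classes above derive from MessageLite rather than Message (i.e. the proto was apparently built for the lite runtime), so they expose the byte-oriented serialization API but no reflection or descriptors. A minimal sketch of round-tripping the message, assuming the matching generated header is named sentencepiece.pb.h and using the standard protoc accessor names for the fields visible above (text, score, pieces, piece, id, begin, end):

#include <cassert>
#include <string>

#include "sentencepiece.pb.h"  // assumed name of the header generated alongside this .pb.cc

int main() {
  sentencepiece::SentencePieceText spt;
  spt.set_text("hello");
  spt.set_score(-3.5f);
  auto* p = spt.add_pieces();
  p->set_piece("he");
  p->set_id(42);
  p->set_begin(0);
  p->set_end(2);

  // Lite messages still support the usual wire round-trip.
  const std::string wire = spt.SerializeAsString();
  sentencepiece::SentencePieceText parsed;
  assert(parsed.ParseFromString(wire));
  assert(parsed.pieces_size() == 1);
  assert(parsed.pieces(0).piece() == "he");
  return 0;
}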
File diff suppressed because it is too large
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -1,46 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.!

#include "char_model.h"
#include "util.h"

namespace sentencepiece {
namespace character {

Model::Model(const ModelProto &model_proto) {
  model_proto_ = &model_proto;
  InitializePieces();
}

Model::~Model() {}

EncodeResult Model::Encode(absl::string_view normalized) const {
  if (!status().ok() || normalized.empty()) {
    return {};
  }

  // Splits the input into character sequence
  EncodeResult output;
  while (!normalized.empty()) {
    const int mblen = matcher_->PrefixMatch(normalized);
    absl::string_view w(normalized.data(), mblen);
    output.emplace_back(w, PieceToId(w));
    normalized.remove_prefix(mblen);
  }

  return output;
}

}  // namespace character
}  // namespace sentencepiece
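Editor's note: Model::Encode above walks the normalized text with matcher_->PrefixMatch, which appears to return the byte length of the next match (a user-defined piece if one matches, otherwise a single UTF-8 character; the tests below exercise both cases). A self-contained sketch of the character-splitting case only, where the hypothetical OneCharLen stands in for the matcher:

#include <iostream>
#include <string>
#include <vector>

// Byte length of the UTF-8 sequence starting at lead byte c; returns 1 for a
// broken byte, which therefore gets emitted on its own, as Model::Encode also does.
static int OneCharLen(unsigned char c) {
  if (c < 0x80) return 1;
  if ((c & 0xE0) == 0xC0) return 2;
  if ((c & 0xF0) == 0xE0) return 3;
  if ((c & 0xF8) == 0xF0) return 4;
  return 1;
}

int main() {
  const std::string normalized = "\xE2\x96\x81" "ab";  // "<U+2581>ab"
  std::vector<std::string> pieces;
  for (size_t pos = 0; pos < normalized.size();) {
    const int mblen = OneCharLen(static_cast<unsigned char>(normalized[pos]));
    pieces.push_back(normalized.substr(pos, mblen));  // substr clamps a truncated tail
    pos += mblen;
  }
  for (const auto& p : pieces) std::cout << p << "\n";  // three pieces: the space symbol, "a", "b"
  return 0;
}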
@@ -1,34 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.!

#ifndef CHAR_MODEL_H_
#define CHAR_MODEL_H_

#include "model_interface.h"
#include "sentencepiece_model.pb.h"

namespace sentencepiece {
namespace character {

// Tokenize text into character sequence
class Model : public ModelInterface {
 public:
  explicit Model(const ModelProto &model_proto);
  ~Model() override;

  EncodeResult Encode(absl::string_view normalized) const override;
};
}  // namespace character
}  // namespace sentencepiece
#endif  // CHAR_MODEL_H_
@@ -1,118 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.!

#include <string>

#include "char_model.h"
#include "testharness.h"
#include "util.h"

namespace sentencepiece {
namespace character {
namespace {

// Space symbol (U+2581)
#define WS "\xe2\x96\x81"

ModelProto MakeBaseModelProto() {
  ModelProto model_proto;
  auto *sp1 = model_proto.add_pieces();
  auto *sp2 = model_proto.add_pieces();
  auto *sp3 = model_proto.add_pieces();

  sp1->set_type(ModelProto::SentencePiece::UNKNOWN);
  sp1->set_piece("<unk>");
  sp2->set_type(ModelProto::SentencePiece::CONTROL);
  sp2->set_piece("<s>");
  sp3->set_type(ModelProto::SentencePiece::CONTROL);
  sp3->set_piece("</s>");

  return model_proto;
}

void AddPiece(ModelProto *model_proto, const std::string &piece,
              float score = 0.0) {
  auto *sp = model_proto->add_pieces();
  sp->set_piece(piece);
  sp->set_score(score);
}

TEST(ModelTest, EncodeTest) {
  ModelProto model_proto = MakeBaseModelProto();

  AddPiece(&model_proto, WS, 0.0);
  AddPiece(&model_proto, "a", 0.1);
  AddPiece(&model_proto, "b", 0.2);
  AddPiece(&model_proto, "c", 0.3);
  AddPiece(&model_proto, "d", 0.4);
  AddPiece(&model_proto, "ABC", 0.4);
  model_proto.mutable_pieces(8)->set_type(
      ModelProto::SentencePiece::USER_DEFINED);

  const Model model(model_proto);

  EncodeResult result;

  result = model.Encode("");
  EXPECT_TRUE(result.empty());

  result = model.Encode(WS "a" WS "b" WS "c");
  EXPECT_EQ(6, result.size());
  EXPECT_EQ(WS, result[0].first);
  EXPECT_EQ("a", result[1].first);
  EXPECT_EQ(WS, result[2].first);
  EXPECT_EQ("b", result[3].first);
  EXPECT_EQ(WS, result[4].first);
  EXPECT_EQ("c", result[5].first);

  result = model.Encode(WS "ab" WS "cd" WS "abc");
  EXPECT_EQ(10, result.size());
  EXPECT_EQ(WS, result[0].first);
  EXPECT_EQ("a", result[1].first);
  EXPECT_EQ("b", result[2].first);
  EXPECT_EQ(WS, result[3].first);
  EXPECT_EQ("c", result[4].first);
  EXPECT_EQ("d", result[5].first);
  EXPECT_EQ(WS, result[6].first);
  EXPECT_EQ("a", result[7].first);
  EXPECT_EQ("b", result[8].first);
  EXPECT_EQ("c", result[9].first);

  // makes a broken utf-8
  const std::string broken_utf8 = std::string("あ").substr(0, 1);
  result = model.Encode(broken_utf8);
  EXPECT_EQ(1, result.size());
  EXPECT_EQ(broken_utf8, result[0].first);

  // "ABC" is treated as one piece, as it is USER_DEFINED.
  result = model.Encode(WS "abABCcd");
  EXPECT_EQ(6, result.size());
  EXPECT_EQ(WS, result[0].first);
  EXPECT_EQ("a", result[1].first);
  EXPECT_EQ("b", result[2].first);
  EXPECT_EQ("ABC", result[3].first);
  EXPECT_EQ("c", result[4].first);
  EXPECT_EQ("d", result[5].first);
}

TEST(CharModelTest, NotSupportedTest) {
  ModelProto model_proto = MakeBaseModelProto();
  const Model model(model_proto);
  EXPECT_EQ(NBestEncodeResult(), model.NBestEncode("test", 10));
  EXPECT_EQ(EncodeResult(), model.SampleEncode("test", 0.1));
}

}  // namespace
}  // namespace character
}  // namespace sentencepiece
@@ -1,60 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.!

#include <cmath>

#include "char_model.h"
#include "char_model_trainer.h"
#include "util.h"

namespace sentencepiece {
namespace character {

util::Status Trainer::Train() {
  RETURN_IF_ERROR(status());

  CHECK_OR_RETURN(normalizer_spec_.escape_whitespaces());
  CHECK_EQ_OR_RETURN(TrainerSpec::CHAR, trainer_spec_.model_type());

  RETURN_IF_ERROR(LoadSentences());

  const int vocab_size = trainer_spec_.vocab_size() - meta_pieces_.size();
  CHECK_GE_OR_RETURN(vocab_size, 0);

  uint64 sum = 0;
  for (const auto &it : required_chars_) {
    sum += it.second;
  }

  const auto logsum = std::log(static_cast<float>(sum));

  CHECK_OR_RETURN(final_pieces_.empty());
  for (const auto &it : Sorted(required_chars_)) {
    if (!trainer_spec_.use_all_vocab() &&
        final_pieces_.size() == static_cast<size_t>(vocab_size)) {
      break;
    }
    final_pieces_.emplace_back(
        string_util::UnicodeCharToUTF8(it.first),
        std::log(static_cast<float>(it.second)) - logsum);
  }

  if (trainer_spec_.use_all_vocab()) {
    trainer_spec_.set_vocab_size(final_pieces_.size() + meta_pieces_.size());
  }

  return Save();
}
}  // namespace character
}  // namespace sentencepiece
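Editor's note: Trainer::Train above scores each retained character as log(count) - log(sum), i.e. its log unigram probability over the training corpus. A tiny self-contained sketch of that scoring rule, with made-up counts standing in for required_chars_:

#include <cmath>
#include <cstdio>
#include <map>

int main() {
  // Hypothetical character frequencies standing in for required_chars_.
  const std::map<char, long long> counts = {{'a', 6}, {'b', 3}, {'c', 1}};
  long long sum = 0;
  for (const auto& kv : counts) sum += kv.second;          // sum = 10
  const float logsum = std::log(static_cast<float>(sum));
  for (const auto& kv : counts) {
    // Same rule as Trainer::Train(): score = log(count) - log(sum) = log p(char).
    const float score = std::log(static_cast<float>(kv.second)) - logsum;
    std::printf("%c -> %.4f\n", kv.first, score);          // e.g. a -> -0.5108
  }
  return 0;
}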
@@ -1,37 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.!

#ifndef CHAR_MODEL_TRAINER_H_
#define CHAR_MODEL_TRAINER_H_

#include "sentencepiece_model.pb.h"
#include "trainer_interface.h"

namespace sentencepiece {
namespace character {

// Trainer class for character model.
class Trainer : public TrainerInterface {
 public:
  Trainer(const TrainerSpec &trainer_spec,
          const NormalizerSpec &normalizer_spec,
          const NormalizerSpec &denormalizer_spec)
      : TrainerInterface::TrainerInterface(trainer_spec, normalizer_spec,
                                           denormalizer_spec) {}

  util::Status Train() override;
};
}  // namespace character
}  // namespace sentencepiece
#endif  // CHAR_MODEL_TRAINER_H_
@@ -1,82 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.!

#include <string>
#include <vector>

#include "char_model_trainer.h"
#include "filesystem.h"
#include "sentencepiece_processor.h"
#include "testharness.h"
#include "third_party/absl/strings/str_cat.h"
#include "third_party/absl/strings/str_join.h"
#include "util.h"

namespace sentencepiece {
namespace character {
namespace {

// Space symbol (U+2581)
#define WS "\xE2\x96\x81"

std::string RunTrainer(const std::vector<std::string> &input, int size) {
  const std::string input_file =
      util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "input");
  const std::string model_prefix =
      util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "model");
  {
    auto output = filesystem::NewWritableFile(input_file);
    for (const auto &line : input) {
      output->WriteLine(line);
    }
  }

  TrainerSpec trainer_spec;
  trainer_spec.set_model_type(TrainerSpec::CHAR);
  trainer_spec.add_input(input_file);
  trainer_spec.set_vocab_size(size);
  trainer_spec.set_model_prefix(model_prefix);

  NormalizerSpec normalizer_spec;
  normalizer_spec.set_name("identity");

  NormalizerSpec denormalizer_spec;

  Trainer trainer(trainer_spec, normalizer_spec, denormalizer_spec);
  EXPECT_TRUE(trainer.Train().ok());

  SentencePieceProcessor processor;
  EXPECT_TRUE(processor.Load(model_prefix + ".model").ok());

  const auto &model = processor.model_proto();
  std::vector<std::string> pieces;

  // remove <unk>, <s>, </s>
  for (int i = 3; i < model.pieces_size(); ++i) {
    pieces.emplace_back(model.pieces(i).piece());
  }

  return absl::StrJoin(pieces, " ");
}

TEST(TrainerTest, BasicTest) {
  EXPECT_EQ(WS " a e p n I h l v",
            RunTrainer({"I have a pen", "I have an apple", "apple pen"}, 100));
  EXPECT_EQ(WS " a",  // <unk>, <s>, </s>, _, a
            RunTrainer({"I have a pen", "I have an apple", "apple pen"}, 5));
}

}  // namespace
}  // namespace character
}  // namespace sentencepiece
@@ -1,177 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.!

#ifndef COMMON_H_
#define COMMON_H_

#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "config.h"
#include "third_party/absl/strings/string_view.h"

#if defined(_WIN32) && !defined(__CYGWIN__)
#define OS_WIN
#else
#define OS_UNIX
#endif

#ifdef OS_WIN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#endif

typedef int8_t int8;
typedef int16_t int16;
typedef int32_t int32;
typedef int64_t int64;
typedef uint8_t uint8;
typedef uint16_t uint16;
typedef uint32_t char32;
typedef uint32_t uint32;
typedef uint64_t uint64;

static constexpr uint32 kUnicodeError = 0xFFFD;

template <typename T, size_t N>
char (&ArraySizeHelper(T (&array)[N]))[N];

#ifndef _MSC_VER
template <typename T, size_t N>
char (&ArraySizeHelper(const T (&array)[N]))[N];
#endif  // !_MSC_VER

#define arraysize(array) (sizeof(ArraySizeHelper(array)))

#if defined(_FREEBSD)
#include <sys/endian.h>
#endif
#if !defined(__APPLE__) && !defined(_WIN32) && !defined(_FREEBSD)
#include <endian.h>
#if BYTE_ORDER == __BIG_ENDIAN
#define IS_BIG_ENDIAN
#endif
#endif

namespace sentencepiece {
namespace util {
#ifndef OS_WIN
inline uint32 Swap32(uint32 x) { return __builtin_bswap32(x); }
#endif  // OS_WIN
}  // namespace util

namespace error {

void Abort();
void Exit(int code);
void SetTestCounter(int c);
void ResetTestMode();
bool GetTestCounter();

class Die {
 public:
  explicit Die(bool die) : die_(die) {}
  ~Die() {
    std::cerr << std::endl;
    if (die_) {
      Abort();
    }
  }
  int operator&(std::ostream &) { return 0; }

 private:
  bool die_;
};
}  // namespace error

namespace logging {
enum LogSeverity {
  LOG_INFO = 0,
  LOG_WARNING = 1,
  LOG_ERROR = 2,
  LOG_FATAL = 3,
  LOG_SEVERITY_SIZE = 4,
};

int GetMinLogLevel();
void SetMinLogLevel(int v);

inline const char *BaseName(const char *path) {
#ifdef OS_WIN
  const char *p = strrchr(path, '\\');
#else
  const char *p = strrchr(path, '/');
#endif
  if (p == nullptr) return path;
  return p + 1;
}
}  // namespace logging
}  // namespace sentencepiece

#define LOG(severity)                                                        \
  (::sentencepiece::logging::GetMinLogLevel() >                              \
   ::sentencepiece::logging::LOG_##severity)                                 \
      ? 0                                                                    \
      : ::sentencepiece::error::Die(                                         \
            ::sentencepiece::logging::LOG_##severity >=                      \
            ::sentencepiece::logging::LOG_FATAL) &                           \
            std::cerr << ::sentencepiece::logging::BaseName(__FILE__) << "(" \
                      << __LINE__ << ") "                                    \
                      << "LOG(" << #severity << ") "

#define CHECK(condition)                                                      \
  (condition) ? 0                                                             \
              : ::sentencepiece::error::Die(true) &                           \
                    std::cerr << ::sentencepiece::logging::BaseName(__FILE__) \
                              << "(" << __LINE__ << ") [" << #condition       \
                              << "] "

#define CHECK_STREQ(a, b) CHECK_EQ(std::string(a), std::string(b))
#define CHECK_EQ(a, b) CHECK((a) == (b))
#define CHECK_NE(a, b) CHECK((a) != (b))
#define CHECK_GE(a, b) CHECK((a) >= (b))
#define CHECK_LE(a, b) CHECK((a) <= (b))
#define CHECK_GT(a, b) CHECK((a) > (b))
#define CHECK_LT(a, b) CHECK((a) < (b))

#define FRIEND_TEST(a, b) friend class a##_Test_##b;

#define CHECK_OK(expr)                         \
  do {                                         \
    const auto _status = expr;                 \
    CHECK(_status.ok()) << _status.ToString(); \
  } while (0)

#define CHECK_NOT_OK(expr)                      \
  do {                                          \
    const auto _status = expr;                  \
    CHECK(!_status.ok()) << _status.ToString(); \
  } while (0)

#define RETURN_IF_ERROR(expr)          \
  do {                                 \
    const auto _status = expr;         \
    if (!_status.ok()) return _status; \
  } while (0)

#endif  // COMMON_H_
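Editor's note: the LOG and CHECK macros above expand to a ternary whose false branch streams into std::cerr through error::Die; Die aborts in its destructor only when constructed with true, i.e. for CHECK failures and LOG(FATAL). A minimal usage sketch, assuming common.h and a definition of logging::GetMinLogLevel are available on the include path:

#include "common.h"

int main() {
  const int vocab_size = 8000;
  // Printed as "<file>(<line>) LOG(INFO) ..." when minloglevel permits; no abort.
  LOG(INFO) << "vocab size: " << vocab_size;
  // On failure this prints "<file>(<line>) [(vocab_size) > (0)] ..." and aborts.
  CHECK_GT(vocab_size, 0) << "vocab size must be positive";
  return 0;
}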
@@ -1,198 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.!

#include <functional>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>

#include "builder.h"
#include "filesystem.h"
#include "init.h"
#include "sentencepiece_processor.h"
#include "third_party/absl/flags/flag.h"
#include "third_party/absl/strings/string_view.h"

using sentencepiece::normalizer::Builder;

ABSL_FLAG(bool, output_precompiled_header, false,
          "make normalization_rule.h file");

namespace sentencepiece {
namespace {

std::string ToHexUInt64Array(
    const std::vector<std::pair<std::string, std::string>> &data,
    std::vector<size_t> *offset) {
  std::stringstream os;
  os.setf(std::ios_base::hex, std::ios_base::basefield);
  os.setf(std::ios_base::uppercase);
  os.setf(std::ios_base::right);
  os.fill('0');
  os.unsetf(std::ios_base::showbase);

  size_t num = 0;
  for (const auto &p : data) {
    const char *begin = p.second.data();
    const char *end = p.second.data() + p.second.size();

    offset->push_back(num);
    while (begin < end) {
      unsigned long long int n = 0;
      unsigned char *buf = reinterpret_cast<unsigned char *>(&n);
      const size_t size = std::min<size_t>(end - begin, sizeof(n));
      for (size_t i = 0; i < size; ++i) {
        buf[i] = static_cast<unsigned char>(begin[i]);
      }
      begin += sizeof(n);
      os << "0x" << std::setw(2 * sizeof(n)) << n << ", ";
      if (++num % 8 == 0) {
        os << "\n";
      }
    }
  }

  return os.str();
}

std::string ToHexData(absl::string_view data) {
  const char *begin = data.data();
  const char *end = data.data() + data.size();
  constexpr char kHex[] = "0123456789ABCDEF";
  constexpr size_t kNumOfBytesOnOneLine = 20;

  size_t output_count = 0;
  std::stringstream os;
  while (begin < end) {
    const size_t bucket_size =
        std::min<size_t>(end - begin, kNumOfBytesOnOneLine -
                                          output_count % kNumOfBytesOnOneLine);
    if (output_count % kNumOfBytesOnOneLine == 0 && bucket_size > 0) {
      os << "\"";
    }
    for (size_t i = 0; i < bucket_size; ++i) {
      os << "\\x" << kHex[(*begin & 0xF0) >> 4] << kHex[(*begin & 0x0F) >> 0];
      ++begin;
    }
    output_count += bucket_size;
    if (output_count % kNumOfBytesOnOneLine == 0 && bucket_size > 0 &&
        begin < end) {
      os << "\"\n";
    }
  }
  os << "\"\n";

  return os.str();
}

std::string MakeHeader(
    const std::vector<std::pair<std::string, std::string>> &data) {
  constexpr char kHeader[] =
      R"(#ifndef NORMALIZATION_RULE_H_
#define NORMALIZATION_RULE_H_
#include <cstdio>
namespace sentencepiece {
namespace {

struct BinaryBlob {
 const char *name;
 size_t size;
 const char *data;
};

)";

  constexpr char kFooter[] = R"(
}  // namespace
}  // namespace sentencepiece
#endif  // NORMALIZATION_RULE_H_
)";

  std::stringstream os;
  os << kHeader;

  os << "#if defined(_WIN32) && !defined(__CYGWIN__)\n";
  os << "constexpr unsigned long long int kNormalizationRules_blob_uint64[] = "
        "{\n";
  std::vector<size_t> offset;
  os << ToHexUInt64Array(data, &offset);
  CHECK_EQ(offset.size(), data.size());
  os << "};\n\n";
  os << "const BinaryBlob kNormalizationRules_blob[] = {\n";
  for (size_t i = 0; i < data.size(); ++i) {
    os << "{ \"" << data[i].first << "\", " << data[i].second.size() << ", ";
    os << "reinterpret_cast<const char *>(kNormalizationRules_blob_uint64 + "
       << offset[i] << ") },\n";
  }
  os << "};\n";
  os << "#else\n";
  os << "constexpr BinaryBlob kNormalizationRules_blob[] = {\n";
  for (size_t i = 0; i < data.size(); ++i) {
    os << "{ \"" << data[i].first << "\", " << data[i].second.size() << ", ";
    os << ToHexData(data[i].second) << "},\n";
  }
  os << "};\n";
  os << "#endif\n";

  os << "constexpr size_t kNormalizationRules_size = " << data.size() << ";\n";
  os << kFooter;

  return os.str();
}

}  // namespace
}  // namespace sentencepiece

int main(int argc, char **argv) {
  sentencepiece::ScopedResourceDestructor cleaner;
  sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true);

  const std::vector<std::pair<
      std::string,
      std::function<sentencepiece::util::Status(Builder::CharsMap *)>>>
      kRuleList = {{"nfkc", Builder::BuildNFKCMap},
                   {"nmt_nfkc", Builder::BuildNmtNFKCMap},
                   {"nfkc_cf", Builder::BuildNFKC_CFMap},
                   {"nmt_nfkc_cf", Builder::BuildNmtNFKC_CFMap},
                   {"nfkd", Builder::BuildNFKDMap}};

  std::vector<std::pair<std::string, std::string>> data;
  for (const auto &p : kRuleList) {
    Builder::CharsMap normalized_map;
    CHECK_OK(p.second(&normalized_map));

    // Write Header.
    std::string index;
    CHECK_OK(Builder::CompileCharsMap(normalized_map, &index));

    // Write TSV file.
    CHECK_OK(Builder::SaveCharsMap(p.first + ".tsv", normalized_map));

    // Do not make NFKD map as it is optionally created.
    if (p.first.find("nfkd") != std::string::npos) continue;

    data.emplace_back(p.first, index);
  }

  if (absl::GetFlag(FLAGS_output_precompiled_header)) {
    constexpr char kPrecompiledHeaderFileName[] = "normalization_rule.h";
    auto output =
        sentencepiece::filesystem::NewWritableFile(kPrecompiledHeaderFileName);
    CHECK_OK(output->status());
    output->Write(sentencepiece::MakeHeader(data));
  }

  return 0;
}
@@ -1,160 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstring>

#include "common.h"
#include "init.h"
#include "sentencepiece_processor.h"

#ifdef _USE_EXTERNAL_ABSL
// Naive workaround to define minloglevel when building against an external
// absl package. Ideally this flag would be defined in a separate .cc file.
#include "third_party/absl/flags/flag.h"
#include "third_party/absl/flags/parse.h"
ABSL_FLAG(int32, minloglevel, 0,
          "Messages logged at a lower level than this don't actually get "
          "logged anywhere.");
#endif

namespace sentencepiece {
namespace error {
int gTestCounter = 0;

void Abort() {
  if (GetTestCounter() == 1) {
    SetTestCounter(2);
  } else {
    std::cerr << "Program terminated with an unrecoverable error." << std::endl;
    ShutdownLibrary();
    exit(-1);
  }
}

void Exit(int code) {
  if (GetTestCounter() == 1) {
    SetTestCounter(2);
  } else {
    ShutdownLibrary();
    exit(code);
  }
}

void SetTestCounter(int c) { gTestCounter = c; }
int GetTestCounter() { return gTestCounter; }
}  // namespace error

namespace util {

Status::Status() {}
Status::~Status() {}

struct Status::Rep {
  StatusCode code;
  std::string error_message;
};

Status::Status(StatusCode code, absl::string_view error_message)
    : rep_(new Rep) {
  rep_->code = code;
  rep_->error_message = std::string(error_message);
}

Status::Status(const Status& s)
    : rep_((s.rep_ == nullptr) ? nullptr : new Rep(*s.rep_)) {}

void Status::operator=(const Status& s) {
  if (rep_ != s.rep_)
    rep_.reset((s.rep_ == nullptr) ? nullptr : new Rep(*s.rep_));
}

bool Status::operator==(const Status& s) const { return (rep_ == s.rep_); }

bool Status::operator!=(const Status& s) const { return (rep_ != s.rep_); }

const char* Status::error_message() const {
  return ok() ? "" : rep_->error_message.c_str();
}

void Status::set_error_message(const char* str) {
  if (rep_ == nullptr) rep_.reset(new Rep);
  rep_->error_message = str;
}

StatusCode Status::code() const { return ok() ? StatusCode::kOk : rep_->code; }

std::string Status::ToString() const {
  if (rep_ == nullptr) return "OK";

  std::string result;
  switch (code()) {
    case StatusCode::kCancelled:
      result = "Cancelled";
      break;
    case StatusCode::kUnknown:
      result = "Unknown";
      break;
    case StatusCode::kInvalidArgument:
      result = "Invalid argument";
      break;
    case StatusCode::kDeadlineExceeded:
      result = "Deadline exceeded";
      break;
    case StatusCode::kNotFound:
      result = "Not found";
      break;
    case StatusCode::kAlreadyExists:
      result = "Already exists";
      break;
    case StatusCode::kPermissionDenied:
      result = "Permission denied";
      break;
    case StatusCode::kResourceExhausted:
      result = "Resource exhausted";
      break;
    case StatusCode::kFailedPrecondition:
      result = "Failed precondition";
      break;
    case StatusCode::kAborted:
      result = "Aborted";
      break;
    case StatusCode::kOutOfRange:
      result = "Out of range";
      break;
    case StatusCode::kUnimplemented:
      result = "Unimplemented";
      break;
    case StatusCode::kInternal:
      result = "Internal";
      break;
    case StatusCode::kUnavailable:
      result = "Unavailable";
      break;
    case StatusCode::kDataLoss:
      result = "Data loss";
      break;
    case StatusCode::kUnauthenticated:
      result = "Unauthenticated";
      break;
    default:
      break;
  }

  result += ": ";
  result += rep_->error_message;
  return result;
}

void Status::IgnoreError() {}

}  // namespace util
}  // namespace sentencepiece
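A minimal usage sketch of the Status class removed above, assuming sentencepiece_processor.h (which declares util::Status and StatusCode) is on the include path; this is illustration only, not part of the diff:

// Sketch: constructing and printing util::Status values.
#include <iostream>

#include "sentencepiece_processor.h"

int main() {
  // Default construction leaves rep_ null, i.e. an OK status.
  sentencepiece::util::Status ok;
  sentencepiece::util::Status err(sentencepiece::util::StatusCode::kNotFound,
                                  "model file missing");
  std::cout << ok.ToString() << "\n";   // prints "OK"
  std::cout << err.ToString() << "\n";  // prints "Not found: model file missing"
  return 0;
}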
@@ -1,116 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "filesystem.h"

#include <fstream>
#include <iostream>
#include <memory>

#include "util.h"

#if defined(OS_WIN) && defined(UNICODE) && defined(_UNICODE)
#define WPATH(path) (::sentencepiece::util::Utf8ToWide(path).c_str())
#else
#define WPATH(path) (path.data())
#endif

namespace sentencepiece {
namespace filesystem {

class PosixReadableFile : public ReadableFile {
 public:
  PosixReadableFile(absl::string_view filename, bool is_binary = false)
      : is_(filename.empty()
                ? &std::cin
                : new std::ifstream(WPATH(filename),
                                    is_binary ? std::ios::binary | std::ios::in
                                              : std::ios::in)) {
    if (!*is_)
      status_ = util::StatusBuilder(util::StatusCode::kNotFound, GTL_LOC)
                << "\"" << filename.data() << "\": " << util::StrError(errno);
  }

  ~PosixReadableFile() {
    if (is_ != &std::cin) delete is_;
  }

  util::Status status() const { return status_; }

  bool ReadLine(std::string *line) {
    return static_cast<bool>(std::getline(*is_, *line));
  }

  bool ReadAll(std::string *line) {
    if (is_ == &std::cin) {
      LOG(ERROR) << "ReadAll is not supported for stdin.";
      return false;
    }
    line->assign(std::istreambuf_iterator<char>(*is_),
                 std::istreambuf_iterator<char>());
    return true;
  }

 private:
  util::Status status_;
  std::istream *is_;
};

class PosixWritableFile : public WritableFile {
 public:
  PosixWritableFile(absl::string_view filename, bool is_binary = false)
      : os_(filename.empty()
                ? &std::cout
                : new std::ofstream(WPATH(filename),
                                    is_binary ? std::ios::binary | std::ios::out
                                              : std::ios::out)) {
    if (!*os_)
      status_ =
          util::StatusBuilder(util::StatusCode::kPermissionDenied, GTL_LOC)
          << "\"" << filename.data() << "\": " << util::StrError(errno);
  }

  ~PosixWritableFile() {
    if (os_ != &std::cout) delete os_;
  }

  util::Status status() const { return status_; }

  bool Write(absl::string_view text) {
    os_->write(text.data(), text.size());
    return os_->good();
  }

  bool WriteLine(absl::string_view text) { return Write(text) && Write("\n"); }

 private:
  util::Status status_;
  std::ostream *os_;
};

using DefaultReadableFile = PosixReadableFile;
using DefaultWritableFile = PosixWritableFile;

std::unique_ptr<ReadableFile> NewReadableFile(absl::string_view filename,
                                              bool is_binary) {
  return std::make_unique<DefaultReadableFile>(filename, is_binary);
}

std::unique_ptr<WritableFile> NewWritableFile(absl::string_view filename,
                                              bool is_binary) {
  return std::make_unique<DefaultWritableFile>(filename, is_binary);
}

}  // namespace filesystem
}  // namespace sentencepiece
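One detail of the implementation above is worth highlighting: an empty filename makes the factory fall back to std::cin or std::cout, so the same interface covers both regular files and the standard streams. A minimal sketch, assuming filesystem.h from this tree:

// Sketch: an empty filename routes writes to std::cout rather than a file.
#include "filesystem.h"

void EchoToStdout() {
  auto out = sentencepiece::filesystem::NewWritableFile("");  // -> std::cout
  out->WriteLine("hello");  // appears on the terminal; no file is created
}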
@@ -1,59 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef FILESYSTEM_H_
#define FILESYSTEM_H_

#include <stdio.h>

#include <fstream>
#include <memory>
#include <string>

#include "common.h"
#include "sentencepiece_processor.h"
#include "third_party/absl/strings/string_view.h"

namespace sentencepiece {
namespace filesystem {
class ReadableFile {
 public:
  ReadableFile() {}
  explicit ReadableFile(absl::string_view filename, bool is_binary = false) {}
  virtual ~ReadableFile() {}

  virtual util::Status status() const = 0;
  virtual bool ReadLine(std::string *line) = 0;
  virtual bool ReadAll(std::string *line) = 0;
};

class WritableFile {
 public:
  WritableFile() {}
  explicit WritableFile(absl::string_view filename, bool is_binary = false) {}
  virtual ~WritableFile() {}

  virtual util::Status status() const = 0;
  virtual bool Write(absl::string_view text) = 0;
  virtual bool WriteLine(absl::string_view text) = 0;
};

std::unique_ptr<ReadableFile> NewReadableFile(absl::string_view filename,
                                              bool is_binary = false);
std::unique_ptr<WritableFile> NewWritableFile(absl::string_view filename,
                                              bool is_binary = false);

}  // namespace filesystem
}  // namespace sentencepiece
#endif  // FILESYSTEM_H_
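Since ReadableFile and WritableFile are pure-virtual interfaces, backends other than the Posix one can be slotted in. A hypothetical in-memory implementation for unit tests, shown for illustration only (StringReadableFile is not part of this tree):

// Hypothetical in-memory ReadableFile; illustrates the interface contract.
#include <sstream>
#include <string>

#include "filesystem.h"

class StringReadableFile : public sentencepiece::filesystem::ReadableFile {
 public:
  explicit StringReadableFile(const std::string &contents) : is_(contents) {}

  sentencepiece::util::Status status() const override { return status_; }

  bool ReadLine(std::string *line) override {
    return static_cast<bool>(std::getline(is_, *line));
  }

  bool ReadAll(std::string *line) override {
    *line = is_.str();  // returns the full buffer regardless of read position
    return true;
  }

 private:
  sentencepiece::util::Status status_;
  std::istringstream is_;
};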
@@ -1,54 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "filesystem.h"
#include "testharness.h"
#include "third_party/absl/strings/str_cat.h"
#include "util.h"

namespace sentencepiece {

TEST(UtilTest, FilesystemTest) {
  // Note: the adjacent string literals concatenate, so kData holds the
  // single element "Thisisatest".
  const std::vector<std::string> kData = {
      "This"
      "is"
      "a"
      "test"};

  {
    auto output = filesystem::NewWritableFile(
        util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test_file"));
    for (size_t i = 0; i < kData.size(); ++i) {
      output->WriteLine(kData[i]);
    }
  }

  {
    auto input = filesystem::NewReadableFile(
        util::JoinPath(absl::GetFlag(FLAGS_test_tmpdir), "test_file"));
    std::string line;
    for (size_t i = 0; i < kData.size(); ++i) {
      EXPECT_TRUE(input->ReadLine(&line));
      EXPECT_EQ(kData[i], line);
    }
    EXPECT_FALSE(input->ReadLine(&line));
  }
}

TEST(UtilTest, FilesystemInvalidFileTest) {
  auto input = filesystem::NewReadableFile("__UNKNOWN__FILE__");
  EXPECT_FALSE(input->status().ok());
}

}  // namespace sentencepiece
@@ -1,90 +0,0 @@
// Copyright 2018 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef FREELIST_H_
#define FREELIST_H_

#include <string.h>

#include <vector>

namespace sentencepiece {
namespace model {

// Simple FreeList that allocates a chunk of T at once.
template <class T>
class FreeList {
 public:
  FreeList() = delete;
  explicit FreeList(size_t chunk_size) : chunk_size_(chunk_size) {}
  virtual ~FreeList() {
    for (auto& chunk : freelist_) delete[] chunk;
  }

  // `Free` doesn't free the objects; it reuses the allocated memory chunks,
  // zeroing them out for the next round of allocations.
  void Free() {
    const int size = std::min<int>(chunk_index_ + 1, freelist_.size());
    for (int i = 0; i < size; ++i) {
      T* chunk = freelist_[i];
      memset(static_cast<void*>(chunk), 0, sizeof(*chunk) * chunk_size_);
    }
    chunk_index_ = 0;
    element_index_ = 0;
  }

  // Returns the number of allocated elements.
  size_t size() const { return chunk_size_ * chunk_index_ + element_index_; }

  void swap(FreeList<T>& other) {
    std::swap(freelist_, other.freelist_);
    std::swap(element_index_, other.element_index_);
    std::swap(chunk_index_, other.chunk_index_);
    std::swap(chunk_size_, other.chunk_size_);
  }

  // Returns the element at `index`, as if all chunks were one flat array.
  T* operator[](size_t index) const {
    return freelist_[index / chunk_size_] + index % chunk_size_;
  }

  // Allocates a new element.
  T* Allocate() {
    if (element_index_ >= chunk_size_) {
      ++chunk_index_;
      element_index_ = 0;
    }

    if (chunk_index_ == freelist_.size()) {
      T* chunk = new T[chunk_size_];
      memset(static_cast<void*>(chunk), 0, sizeof(*chunk) * chunk_size_);
      freelist_.push_back(chunk);
    }

    T* result = freelist_[chunk_index_] + element_index_;
    ++element_index_;

    return result;
  }

 private:
  std::vector<T*> freelist_;

  // The last element is stored at freelist_[chunk_index_][element_index_].
  size_t element_index_ = 0;
  size_t chunk_index_ = 0;
  size_t chunk_size_ = 0;  // Do not modify except in swap().
};
}  // namespace model
}  // namespace sentencepiece
#endif  // FREELIST_H_
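A short sketch of the allocation pattern this data structure supports, assuming the freelist.h above; elements come back zero-initialized, and Free() recycles chunks instead of returning memory to the heap:

// Sketch: chunked allocation and recycling with FreeList.
#include "freelist.h"

void Demo() {
  sentencepiece::model::FreeList<int> pool(1024);  // 1024 ints per chunk
  int *a = pool.Allocate();  // *a == 0 (zero-initialized)
  int *b = pool.Allocate();
  *a = 1;
  *b = 2;
  // pool.size() == 2; operator[] gives flat indexed access:
  // *pool[0] == 1, *pool[1] == 2.
  pool.Free();  // chunks kept, contents zeroed, size() back to 0
}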
@@ -1,51 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "freelist.h"
#include "testharness.h"

namespace sentencepiece {
namespace model {

TEST(FreeListTest, BasicTest) {
  FreeList<int> l(5);
  EXPECT_EQ(0, l.size());

  constexpr size_t kSize = 32;

  for (size_t i = 0; i < kSize; ++i) {
    int *n = l.Allocate();
    EXPECT_EQ(0, *n);
    *n = i;
  }

  FreeList<int> l2(3);  // Test swap().
  l.swap(l2);

  EXPECT_EQ(kSize, l2.size());
  for (size_t i = 0; i < kSize; ++i) {
    EXPECT_EQ(i, *l2[i]);
  }

  l2.Free();
  EXPECT_EQ(0, l2.size());

  // Zero-initialized after `Free`.
  for (size_t i = 0; i < kSize; ++i) {
    int *n = l2.Allocate();
    EXPECT_EQ(0, *n);
  }
}
}  // namespace model
}  // namespace sentencepiece
@@ -1,60 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef INIT_H_
#define INIT_H_

#include "common.h"
#include "third_party/absl/flags/flag.h"
#include "third_party/absl/flags/parse.h"

#ifdef _USE_EXTERNAL_PROTOBUF
#include "google/protobuf/message_lite.h"
#else
#include "third_party/protobuf-lite/google/protobuf/message_lite.h"
#endif

ABSL_DECLARE_FLAG(int32, minloglevel);

namespace sentencepiece {
inline void ParseCommandLineFlags(const char *usage, int *argc, char ***argv,
                                  bool remove_arg = true) {
  const auto unused_args = absl::ParseCommandLine(*argc, *argv);

  if (remove_arg) {
    char **argv_val = *argv;
    *argv = argv_val = argv_val + *argc - unused_args.size();
    std::copy(unused_args.begin(), unused_args.end(), argv_val);
    *argc = static_cast<int>(unused_args.size());
  }

  logging::SetMinLogLevel(absl::GetFlag(FLAGS_minloglevel));
}

inline void ShutdownLibrary() {
  google::protobuf::ShutdownProtobufLibrary();
#ifdef HAS_ABSL_CLEANUP_FLAGS
  absl::CleanupFlags();
#endif
}

class ScopedResourceDestructor {
 public:
  ScopedResourceDestructor() {}
  ~ScopedResourceDestructor() { ShutdownLibrary(); }
};

}  // namespace sentencepiece

#endif  // INIT_H_
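The intended call pattern for these helpers is the main() prologue used by the tools in this tree (compare the compile_charsmap main further up); a sketch:

// Sketch of the standard main() prologue for the sentencepiece tools:
// the scoped destructor runs ShutdownLibrary() when main returns, and
// ParseCommandLineFlags strips recognized flags out of argv.
#include "init.h"

int main(int argc, char **argv) {
  sentencepiece::ScopedResourceDestructor cleaner;
  sentencepiece::ParseCommandLineFlags(argv[0], &argc, &argv, true);
  // After parsing, argv contains only the program name and positional args.
  return 0;
}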
@@ -1,147 +0,0 @@
// Copyright 2016 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "init.h"

#include "common.h"
#include "testharness.h"

ABSL_FLAG(int32, int32_f, 10, "int32_flags");
ABSL_FLAG(bool, bool_f, false, "bool_flags");
ABSL_FLAG(int64, int64_f, 9223372036854775807LL, "int64_flags");
ABSL_FLAG(uint64, uint64_f, 18446744073709551615ULL, "uint64_flags");
ABSL_FLAG(double, double_f, 40.0, "double_flags");
ABSL_FLAG(std::string, string_f, "str", "string_flags");

ABSL_DECLARE_FLAG(bool, help);
ABSL_DECLARE_FLAG(bool, version);

using sentencepiece::ParseCommandLineFlags;

namespace absl {
TEST(FlagsTest, DefaultValueTest) {
  EXPECT_EQ(10, absl::GetFlag(FLAGS_int32_f));
  EXPECT_EQ(false, absl::GetFlag(FLAGS_bool_f));
  EXPECT_EQ(9223372036854775807LL, absl::GetFlag(FLAGS_int64_f));
  EXPECT_EQ(18446744073709551615ULL, absl::GetFlag(FLAGS_uint64_f));
  EXPECT_EQ(40.0, absl::GetFlag(FLAGS_double_f));
  EXPECT_EQ("str", absl::GetFlag(FLAGS_string_f));
}

TEST(FlagsTest, ParseCommandLineFlagsTest) {
  const char *kFlags[] = {"program",        "--int32_f=100",  "other1",
                          "--bool_f=true",  "--int64_f=200",  "--uint64_f=300",
                          "--double_f=400", "--string_f=foo", "other2",
                          "other3"};
  int argc = arraysize(kFlags);
  char **argv = const_cast<char **>(kFlags);
  ParseCommandLineFlags(kFlags[0], &argc, &argv);

  EXPECT_EQ(100, absl::GetFlag(FLAGS_int32_f));
  EXPECT_EQ(true, absl::GetFlag(FLAGS_bool_f));
  EXPECT_EQ(200, absl::GetFlag(FLAGS_int64_f));
  EXPECT_EQ(300, absl::GetFlag(FLAGS_uint64_f));
  EXPECT_EQ(400.0, absl::GetFlag(FLAGS_double_f));
  EXPECT_EQ("foo", absl::GetFlag(FLAGS_string_f));
  EXPECT_EQ(4, argc);
  EXPECT_EQ("program", std::string(argv[0]));
  EXPECT_EQ("other1", std::string(argv[1]));
  EXPECT_EQ("other2", std::string(argv[2]));
  EXPECT_EQ("other3", std::string(argv[3]));
}

TEST(FlagsTest, ParseCommandLineFlagsTest2) {
  const char *kFlags[] = {"program",      "--int32_f", "500",
                          "-int64_f=600", "-uint64_f", "700",
                          "--bool_f=FALSE"};
  int argc = arraysize(kFlags);
  char **argv = const_cast<char **>(kFlags);
  ParseCommandLineFlags(kFlags[0], &argc, &argv);

  EXPECT_EQ(500, absl::GetFlag(FLAGS_int32_f));
  EXPECT_EQ(600, absl::GetFlag(FLAGS_int64_f));
  EXPECT_EQ(700, absl::GetFlag(FLAGS_uint64_f));
  EXPECT_FALSE(absl::GetFlag(FLAGS_bool_f));
  EXPECT_EQ(1, argc);
}

TEST(FlagsTest, ParseCommandLineFlagsTest3) {
  const char *kFlags[] = {"program", "--bool_f", "--int32_f", "800"};

  int argc = arraysize(kFlags);
  char **argv = const_cast<char **>(kFlags);
  ParseCommandLineFlags(kFlags[0], &argc, &argv);
  EXPECT_TRUE(absl::GetFlag(FLAGS_bool_f));
  EXPECT_EQ(800, absl::GetFlag(FLAGS_int32_f));
  EXPECT_EQ(1, argc);
}

#ifndef _USE_EXTERNAL_ABSL

TEST(FlagsTest, ParseCommandLineFlagsHelpTest) {
  const char *kFlags[] = {"program", "--help"};
  int argc = arraysize(kFlags);
  char **argv = const_cast<char **>(kFlags);
  EXPECT_DEATH(ParseCommandLineFlags(kFlags[0], &argc, &argv), "");
  absl::SetFlag(&FLAGS_help, false);
}

TEST(FlagsTest, ParseCommandLineFlagsVersionTest) {
  const char *kFlags[] = {"program", "--version"};
  int argc = arraysize(kFlags);
  char **argv = const_cast<char **>(kFlags);
  EXPECT_DEATH(ParseCommandLineFlags(kFlags[0], &argc, &argv), "");
  absl::SetFlag(&FLAGS_version, false);
}

TEST(FlagsTest, ParseCommandLineFlagsUnknownTest) {
  const char *kFlags[] = {"program", "--foo"};
  int argc = arraysize(kFlags);
  char **argv = const_cast<char **>(kFlags);
  EXPECT_DEATH(ParseCommandLineFlags(kFlags[0], &argc, &argv), "");
}

TEST(FlagsTest, ParseCommandLineFlagsInvalidBoolTest) {
  const char *kFlags[] = {"program", "--bool_f=X"};
  int argc = arraysize(kFlags);
  char **argv = const_cast<char **>(kFlags);
  EXPECT_DEATH(ParseCommandLineFlags(kFlags[0], &argc, &argv), "");
}

TEST(FlagsTest, ParseCommandLineFlagsEmptyStringArgs) {
  const char *kFlags[] = {"program", "--string_f="};
  int argc = arraysize(kFlags);
  char **argv = const_cast<char **>(kFlags);
  ParseCommandLineFlags(kFlags[0], &argc, &argv);
  EXPECT_EQ(1, argc);
  EXPECT_EQ("", absl::GetFlag(FLAGS_string_f));
}

TEST(FlagsTest, ParseCommandLineFlagsEmptyBoolArgs) {
  const char *kFlags[] = {"program", "--bool_f"};
  int argc = arraysize(kFlags);
  char **argv = const_cast<char **>(kFlags);
  ParseCommandLineFlags(kFlags[0], &argc, &argv);
  EXPECT_EQ(1, argc);
  EXPECT_TRUE(absl::GetFlag(FLAGS_bool_f));
}

TEST(FlagsTest, ParseCommandLineFlagsEmptyIntArgs) {
  const char *kFlags[] = {"program", "--int32_f"};
  int argc = arraysize(kFlags);
  char **argv = const_cast<char **>(kFlags);
  EXPECT_DEATH(ParseCommandLineFlags(kFlags[0], &argc, &argv), "");
}
#endif  // _USE_EXTERNAL_ABSL
}  // namespace absl
Some files were not shown because too many files have changed in this diff.